|
- from ui.components.plot_figure import PlotFrame
- import math
- from abc import ABC
-
- from matplotlib.figure import Figure
-
- from neural_net.activation_layers.activation_layer import ActivationLayer
- from neural_net.neural_net import NeuralNet
- from ui.plotters.plotter import Plotter
- from utils.matplotlib.utils import mpl_matshow
-
class LayerWeightsPlot(PlotFrame):
    """Frame widget that displays a grid plot of a layer's weights.

    Thin UI wrapper: all actual drawing is delegated to a
    ``LayerWeightsPlotter`` bound to this frame's figure.
    """

    def __init__(self, parent, neural_net: NeuralNet, layer: ActivationLayer, rows, cols):
        super().__init__(parent)
        # The plotter owns the rendering logic; this frame only hosts the figure.
        plotter = LayerWeightsPlotter(self.figure, neural_net, layer, rows, cols)
        self.plotter = plotter
-
class LayerWeightsPlotter(Plotter, ABC):
    """Plots each neuron's incoming weight vector as a small square image.

    The figure is partitioned into a ``rows`` x ``columns`` grid of axes;
    neuron *i* is drawn into cell ``(i // columns, i % columns)``.
    """

    def __init__(self, figure: Figure, neural_net: NeuralNet, layer: ActivationLayer, rows, columns):
        """Create the axes grid on *figure* for visualizing *layer*'s weights.

        :param figure: matplotlib figure to draw into.
        :param neural_net: owning network (kept for callers/subclasses).
        :param layer: layer whose ``weights`` matrix is visualized.
        :param rows: number of grid rows.
        :param columns: number of grid columns.
        """
        super().__init__(figure)
        self.neural_net = neural_net
        self.layer = layer
        # squeeze=False guarantees self.axes is always a 2-D (rows, columns)
        # ndarray. The previous squeeze=True collapsed the array to 1-D (or a
        # bare Axes) whenever rows == 1 or columns == 1, which broke the 2-D
        # indexing in reset_plot() and plot().
        self.axes = figure.subplots(nrows=rows, ncols=columns, squeeze=False,
                                    gridspec_kw={'wspace': 0.05, 'hspace': 0.05})

    def reset_plot(self):
        """Clear every axes in the grid so the next plot() starts fresh."""
        for ax in self.axes.flat:
            ax.clear()

    def plot(self, data):
        """Draw one weight image per neuron.

        ``data`` is accepted for interface compatibility but unused; weights
        are read live from ``self.layer``. Assumes each neuron's weight vector
        has a perfect-square length so it can render as a square image —
        TODO confirm against mpl_matshow's contract.
        """
        weights = self.layer.weights.T  # shape: (n_neurons, n_pixels)
        n_neurons, n_pixels = weights.shape
        n_cols = self.axes.shape[1]
        side = math.isqrt(n_pixels)  # exact integer sqrt (no float rounding)
        # Cap at the number of available cells: the original indexed past the
        # grid (IndexError) when the layer had more neurons than axes.
        for i in range(min(n_neurons, self.axes.size)):
            mpl_matshow(self.axes[i // n_cols, i % n_cols], weights[i], side)
-
|