import math
from abc import ABC

from matplotlib.figure import Figure

from neural_net.activation_layers.activation_layer import ActivationLayer
from neural_net.neural_net import NeuralNet
from ui.components.plot_figure import PlotFrame
from ui.plotters.plotter import Plotter
from utils.matplotlib.utils import mpl_matshow


class LayerWeightsPlot(PlotFrame):
    """Plot frame that shows the weights of one layer as a grid of subplots.

    Drawing is delegated to a ``LayerWeightsPlotter`` that renders into
    ``self.figure`` (presumably the matplotlib figure owned by ``PlotFrame``).
    """

    def __init__(self, parent, neural_net: NeuralNet, layer: ActivationLayer, rows, cols):
        super().__init__(parent)
        self.plotter = LayerWeightsPlotter(self.figure, neural_net, layer, rows, cols)


class LayerWeightsPlotter(Plotter, ABC):
    """Draw one subplot per neuron of ``layer`` showing that neuron's weights.

    The figure is split into a ``rows`` x ``columns`` grid of tightly spaced
    subplots, filled in row-major order.
    """

    def __init__(self, figure: Figure, neural_net: NeuralNet, layer: ActivationLayer, rows, columns):
        super().__init__(figure)
        self.neural_net = neural_net
        self.layer = layer
        # squeeze=False guarantees a 2-D ndarray of Axes for every grid shape.
        # With squeeze=True a single row/column produced a 1-D array and a 1x1
        # grid produced a bare Axes, which broke reset_plot() and plot().
        self.axes = figure.subplots(
            nrows=rows,
            ncols=columns,
            squeeze=False,
            gridspec_kw={'wspace': 0.05, 'hspace': 0.05},
        )

    def reset_plot(self):
        """Clear the contents of every subplot in the grid."""
        for ax in self.axes.flat:
            ax.clear()

    def plot(self, data):
        """Render each neuron's weights into its own subplot.

        ``data`` is accepted for the ``Plotter`` interface but is not used;
        the weights are read directly from ``self.layer.weights``.

        After transposing, each row of the weight matrix is treated as one
        neuron (assumes ``layer.weights`` is laid out as (inputs, neurons) —
        TODO confirm). Each row is passed to ``mpl_matshow`` with a width of
        ``int(sqrt(n_inputs))``, presumably to display it as a square image.
        Neurons are assigned to subplots in row-major order; if there are more
        neurons than grid cells, only the first ``rows * columns`` are drawn.
        """
        weights = self.layer.weights.T
        n_pixels = weights.shape[1]
        # Same for every neuron, so compute it once.
        width = int(math.sqrt(n_pixels))
        # zip stops at the shorter sequence, so a layer larger than the grid no
        # longer raises IndexError; .flat walks the grid in row-major order,
        # matching the original (i // ncols, i % ncols) placement.
        for ax, neuron_weights in zip(self.axes.flat, weights):
            mpl_matshow(ax, neuron_weights, width)