"""Gradient visualisations (heat maps and histograms) drawn on a matplotlib figure."""

from abc import ABC

from matplotlib.figure import Figure

from neural_net.epoch import Epoch
from neural_net.neural_net import NeuralNet
from ui.components.plot_figure import PlotFrame
from ui.plotters.plotter import Plotter


class GradientsPlot(PlotFrame):
    """Plot frame whose plotter is a ``GradientsPlotter``."""

    def __init__(self, parent, neural_net: NeuralNet):
        """Initialise the frame and attach a gradients plotter to its figure.

        Args:
            parent: Forwarded unchanged to ``PlotFrame.__init__``.
            neural_net: Network handed to the ``GradientsPlotter``.
        """
        super().__init__(parent)
        # ``self.figure`` is not assigned in this class; presumably it is
        # provided by ``PlotFrame.__init__`` -- confirm against PlotFrame.
        self.plotter = GradientsPlotter(self.figure, neural_net)


class GradientsPlotter(Plotter, ABC):
    """Draws gradients taken from an ``Epoch`` onto two side-by-side axes."""

    def __init__(self, figure: Figure, neural_net: NeuralNet):
        """Keep a reference to the network and create a 1x2 grid of axes.

        Args:
            figure: Figure the two subplots are created on.
            neural_net: Stored on the instance as ``self.neural_net``.
        """
        super().__init__(figure)
        self.neural_net = neural_net
        self.axes = figure.subplots(1, 2)

    def reset_plot(self):
        """Clear both axes and restore their x/y axis labels."""
        label_pairs = (
            ('Neuron Index', 'Input Index'),
            ('Output Neuron Index', 'Hidden Neuron Index'),
        )
        for panel, (x_label, y_label) in zip(self.axes, label_pairs):
            # clear() must come first, otherwise it would wipe the new labels.
            panel.clear()
            panel.set_xlabel(x_label)
            panel.set_ylabel(y_label)

    def plot(self, data: Epoch):
        """Render the gradients of ``data`` as two heat maps.

        The left axes shows ``data.layer_dl_gradients[1][-1]`` and the right
        axes shows ``data.layer_dl_gradients[0][-1]``.

        NOTE(review): the index order (1 on the left, 0 on the right)
        presumably reflects how ``Epoch`` stores its gradients -- confirm.

        Args:
            data: Epoch whose ``layer_dl_gradients`` entries are drawn.
        """
        # Layer index shown on the left axes, then on the right axes.
        for panel, layer_index in zip(self.axes, (1, 0)):
            latest = data.layer_dl_gradients[layer_index][-1]
            panel.imshow(latest, cmap='coolwarm', aspect='auto')

    def plot_gradients_histogram(self, current_epoch: Epoch):
        """Render the gradients of ``current_epoch`` as two 50-bin histograms.

        Uses the same entries as ``plot`` (index 1 on the left axes in blue,
        index 0 on the right axes in green), flattened before binning.

        Args:
            current_epoch: Epoch whose ``layer_dl_gradients`` entries are
                histogrammed.
        """
        # (layer index, bar colour) for the left axes, then the right axes.
        histogram_specs = ((1, 'blue'), (0, 'green'))
        for panel, (layer_index, bar_colour) in zip(self.axes, histogram_specs):
            latest = current_epoch.layer_dl_gradients[layer_index][-1]
            panel.hist(latest.flatten(), bins=50, color=bar_colour, alpha=0.7)