|
- from abc import ABC
-
- from matplotlib.figure import Figure
-
- from neural_net.epoch import Epoch
- from neural_net.neural_net import NeuralNet
- from ui.components.plot_figure import PlotFrame
- from ui.plotters.plotter import Plotter
-
class GradientsPlot(PlotFrame):
    """``PlotFrame`` subclass that attaches a ``GradientsPlotter`` to its figure.

    The plotter is stored on the instance as ``self.plotter``.
    """

    def __init__(self, parent, neural_net: NeuralNet):
        """Initialise the frame and create its gradients plotter.

        Args:
            parent: Forwarded unchanged to ``PlotFrame.__init__``.
            neural_net: Network instance handed to the ``GradientsPlotter``.
        """
        super().__init__(parent)
        # NOTE(review): ``self.figure`` is not assigned in this class; it is
        # presumably created by ``PlotFrame.__init__`` -- confirm.
        self.plotter = GradientsPlotter(
            figure=self.figure,
            neural_net=neural_net,
        )
-
class GradientsPlotter(Plotter, ABC):
    """Draws the gradients stored on an ``Epoch`` onto two side-by-side axes.

    The figure is split into a 1x2 grid of subplots kept in ``self.axes``.
    The first axes shows the last entry of ``layer_dl_gradients[1]`` and the
    second axes shows the last entry of ``layer_dl_gradients[0]``.

    NOTE(review): index 1 feeds the first axes and index 0 the second;
    presumably the gradients are stored in back-propagation order (output
    layer first) -- confirm against ``Epoch``.
    """

    def __init__(self, figure: Figure, neural_net: NeuralNet):
        """Create the two subplots and remember the network.

        Args:
            figure: Figure to draw on; forwarded to ``Plotter.__init__``.
            neural_net: Network kept on the instance as ``self.neural_net``.
        """
        super().__init__(figure)
        self.neural_net = neural_net
        # One row, two columns: index 0 is the first axes, index 1 the second.
        self.axes = figure.subplots(1, 2)

    def reset_plot(self):
        """Clear both axes and re-apply their x/y axis labels."""
        axis_labels = (
            ('Neuron Index', 'Input Index'),
            ('Output Neuron Index', 'Hidden Neuron Index'),
        )
        for axis_index, (x_label, y_label) in enumerate(axis_labels):
            axis = self.axes[axis_index]
            # clear() comes first, exactly as before, so the labels set
            # afterwards are the ones left on the axes.
            axis.clear()
            axis.set_xlabel(x_label)
            axis.set_ylabel(y_label)

    def plot(self, data: Epoch):
        """Render the latest gradients of each layer as a ``coolwarm`` image.

        Args:
            data: Epoch whose ``layer_dl_gradients`` entries are drawn.

        NOTE(review): the axes are not cleared here; callers presumably
        invoke ``reset_plot()`` before each redraw -- verify.
        """
        # (axes index, layer_dl_gradients index) pairs, in drawing order.
        for axis_index, layer_index in ((0, 1), (1, 0)):
            latest = data.layer_dl_gradients[layer_index][-1]
            self.axes[axis_index].imshow(latest, cmap='coolwarm', aspect='auto')

    def plot_gradients_histogram(self, current_epoch: Epoch):
        """Draw a 50-bin histogram of the flattened latest gradients per layer.

        Args:
            current_epoch: Epoch whose ``layer_dl_gradients`` entries are
                histogrammed (blue on the first axes, green on the second).

        NOTE(review): the axes are not cleared here either; see ``plot``.
        """
        # (axes index, layer_dl_gradients index, bar colour), in drawing order.
        histogram_specs = ((0, 1, 'blue'), (1, 0, 'green'))
        for axis_index, layer_index, colour in histogram_specs:
            latest = current_epoch.layer_dl_gradients[layer_index][-1]
            self.axes[axis_index].hist(
                latest.flatten(), bins=50, color=colour, alpha=0.7
            )
|