Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

41 lines
1.6KB

from abc import ABC

from matplotlib.colors import CenteredNorm
from matplotlib.figure import Figure

from neural_net.epoch import Epoch
from neural_net.neural_net import NeuralNet
from ui.components.plot_figure import PlotFrame
from ui.plotters.plotter import Plotter
  7. class GradientsPlot(PlotFrame):
  8. def __init__(self, parent, neural_net: NeuralNet):
  9. super().__init__(parent)
  10. self.plotter = GradientsPlotter(self.figure, neural_net)
  11. class GradientsPlotter(Plotter, ABC):
  12. def __init__(self, figure: Figure, neural_net: NeuralNet):
  13. super().__init__(figure)
  14. self.neural_net = neural_net
  15. self.axes = figure.subplots(1, 2)
  16. def reset_plot(self):
  17. self.axes[0].clear()
  18. self.axes[0].set_xlabel('Neuron Index')
  19. self.axes[0].set_ylabel('Input Index')
  20. self.axes[1].clear()
  21. self.axes[1].set_xlabel('Output Neuron Index')
  22. self.axes[1].set_ylabel('Hidden Neuron Index')
  23. def plot(self, data: Epoch):
  24. gradients_layer1 = data.layer_dl_gradients[1][-1]
  25. self.axes[0].imshow(gradients_layer1, cmap='coolwarm', aspect='auto')
  26. gradients_layer2 = data.layer_dl_gradients[0][-1]
  27. self.axes[1].imshow(gradients_layer2, cmap='coolwarm', aspect='auto')
  28. def plot_gradients_histogram(self, current_epoch: Epoch):
  29. gradients_layer1 = current_epoch.layer_dl_gradients[1][-1]
  30. self.axes[0].hist(gradients_layer1.flatten(), bins=50, color='blue', alpha=0.7)
  31. gradients_layer2 = current_epoch.layer_dl_gradients[0][-1]
  32. self.axes[1].hist(gradients_layer2.flatten(), bins=50, color='green', alpha=0.7)