Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen, können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

40 Zeilen
1.4KB

  1. from ui.components.plot_figure import PlotFrame
  2. import math
  3. from abc import ABC
  4. from matplotlib.figure import Figure
  5. from neural_net.activation_layers.activation_layer import ActivationLayer
  6. from neural_net.neural_net import NeuralNet
  7. from ui.plotters.plotter import Plotter
  8. from utils.matplotlib.utils import mpl_matshow
  9. class LayerWeightsPlot(PlotFrame):
  10. def __init__(self, parent, neural_net: NeuralNet, layer: ActivationLayer, rows, cols):
  11. super().__init__(parent)
  12. self.plotter = LayerWeightsPlotter(self.figure, neural_net, layer, rows, cols)
  13. class LayerWeightsPlotter(Plotter, ABC):
  14. def __init__(self, figure: Figure, neural_net: NeuralNet, layer: ActivationLayer, rows, columns):
  15. super().__init__(figure)
  16. self.neural_net = neural_net
  17. self.layer = layer
  18. self.axes = figure.subplots(nrows=rows, ncols=columns, squeeze=True,
  19. gridspec_kw={'wspace': 0.05, 'hspace': 0.05})
  20. def reset_plot(self):
  21. for axes in self.axes:
  22. for ax in axes:
  23. ax.clear()
  24. def plot(self, data):
  25. weights = self.layer.weights.T
  26. n_neurons = weights.shape[0]
  27. n_pixels = weights.shape[1]
  28. for i in range(n_neurons):
  29. row = i // self.axes.shape[1]
  30. col = i % self.axes.shape[1]
  31. mpl_matshow(self.axes[row, col], weights[i], int(math.sqrt(n_pixels)))