您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

40 行
1.3KB

  1. from neural_net.trainer import NeuralNetTrainer
  2. from ui.components.plot_figure import PlotFrame
  3. from abc import ABC
  4. from matplotlib.figure import Figure
  5. from neural_net.neural_net import NeuralNet
  6. from ui.plotters.plotter import Plotter
  7. class LossPlot(PlotFrame):
  8. def __init__(self, parent, neural_net: NeuralNet, trainer: NeuralNetTrainer):
  9. super().__init__(parent)
  10. self.plotter = LossPlotter(self.figure, neural_net, trainer)
  11. class LossPlotter(Plotter, ABC):
  12. def __init__(self, figure: Figure, neural_net: NeuralNet, trainer: NeuralNetTrainer):
  13. super().__init__(figure)
  14. self.neural_net = neural_net
  15. self.trainer = trainer
  16. self.axes = figure.add_subplot()
  17. def reset_plot(self):
  18. self.axes.clear()
  19. self.axes.set_title('Loss')
  20. self.axes.set_ylabel("Loss")
  21. self.axes.set_xlabel("Epoch")
  22. def plot(self, data):
  23. losses = []
  24. for epoch in self.trainer.epoch_history:
  25. if epoch.finished:
  26. losses.append(epoch.loss)
  27. self.axes.plot(losses, marker='o', label=f"Loss")
  28. for idx, loss in enumerate(losses):
  29. self.axes.annotate(f"{loss:.4f}", xy=(idx, loss), rotation=45)
  30. self.axes.legend()
  31. self.axes.grid(True)