|
- from neural_net.trainer import NeuralNetTrainer
- from ui.components.plot_figure import PlotFrame
-
- from abc import ABC
-
- from matplotlib.figure import Figure
-
- from neural_net.neural_net import NeuralNet
- from ui.plotters.plotter import Plotter
-
class LossPlot(PlotFrame):
    """Plot frame that hosts a ``LossPlotter`` drawing onto this frame's figure."""

    def __init__(self, parent, neural_net: NeuralNet, trainer: NeuralNetTrainer):
        super().__init__(parent)
        # ``self.figure`` is read only after the base initializer has run;
        # presumably PlotFrame creates it there — confirm against PlotFrame.
        loss_plotter = LossPlotter(
            figure=self.figure,
            neural_net=neural_net,
            trainer=trainer,
        )
        self.plotter = loss_plotter
-
class LossPlotter(Plotter, ABC):
    """Draw the per-epoch loss recorded in a trainer's epoch history.

    Only epochs whose ``finished`` attribute is truthy contribute a point, and
    every point is annotated with its loss formatted to four decimal places.
    """

    def __init__(self, figure: Figure, neural_net: NeuralNet, trainer: NeuralNetTrainer):
        super().__init__(figure)
        self.neural_net = neural_net
        self.trainer = trainer
        # add_subplot() with no arguments creates one subplot covering the figure.
        self.axes = figure.add_subplot()

    def reset_plot(self):
        """Clear the axes, then restore the title and axis labels."""
        ax = self.axes
        ax.clear()
        ax.set_title('Loss')
        ax.set_ylabel("Loss")
        ax.set_xlabel("Epoch")

    def plot(self, data):
        """Plot the loss of every finished epoch onto the axes.

        This method does not clear the axes itself; ``reset_plot`` does that.

        Args:
            data: Unused here — losses are read from
                ``self.trainer.epoch_history``. Presumably kept to satisfy the
                ``Plotter`` interface; confirm against the base class.
        """
        finished_losses = [
            epoch.loss
            for epoch in self.trainer.epoch_history
            if epoch.finished
        ]

        ax = self.axes
        ax.plot(finished_losses, marker='o', label="Loss")
        # x coordinates are positions among finished epochs (0-based), matching
        # the implicit x values matplotlib assigns to the line above.
        for position, value in enumerate(finished_losses):
            ax.annotate(f"{value:.4f}", xy=(position, value), rotation=45)

        ax.legend()
        ax.grid(True)
|