from neural_net.trainer import NeuralNetTrainer
from ui.components.plot_figure import PlotFrame
from abc import ABC
from matplotlib.figure import Figure
from neural_net.neural_net import NeuralNet
from ui.plotters.plotter import Plotter


class LossPlot(PlotFrame):
    """Plot frame that shows the training loss of a neural net."""

    def __init__(self, parent, neural_net: NeuralNet, trainer: NeuralNetTrainer):
        """Create the frame and attach a ``LossPlotter`` to its figure.

        Args:
            parent: Parent widget, forwarded unchanged to ``PlotFrame``.
            neural_net: Network being trained; handed to the plotter.
            trainer: Trainer whose ``epoch_history`` supplies the losses.
        """
        super().__init__(parent)
        # NOTE(review): relies on PlotFrame.__init__ having created
        # ``self.figure`` — confirm against PlotFrame.
        self.plotter = LossPlotter(self.figure, neural_net, trainer)


class LossPlotter(Plotter, ABC):
    """Draws the per-epoch training loss onto a single subplot of a figure."""

    def __init__(self, figure: Figure, neural_net: NeuralNet, trainer: NeuralNetTrainer):
        """Bind the plotter to ``figure`` and add one subplot to draw on.

        Args:
            figure: Matplotlib figure that receives a new subplot.
            neural_net: Stored on the instance; not read by the methods
                defined in this class.
            trainer: Source of ``epoch_history`` read by ``plot``.
        """
        super().__init__(figure)
        self.neural_net = neural_net
        self.trainer = trainer
        self.axes = figure.add_subplot()

    def reset_plot(self):
        """Clear the axes and restore the title and axis labels."""
        self.axes.clear()
        self.axes.set_title('Loss')
        self.axes.set_ylabel("Loss")
        self.axes.set_xlabel("Epoch")

    def plot(self, data):
        """Plot and annotate the loss of every finished epoch in the trainer.

        This method does not clear the axes, so calling it repeatedly overlays
        additional lines; presumably callers invoke ``reset_plot`` first —
        confirm against the ``Plotter`` base class.

        Args:
            data: Unused here; accepted presumably to satisfy the
                ``Plotter.plot`` signature.
        """
        # Epochs without ``finished`` set are skipped, so the x position is the
        # index among finished epochs (0-based), not necessarily the epoch number.
        losses = [epoch.loss for epoch in self.trainer.epoch_history if epoch.finished]
        # With only y-values given, matplotlib uses 0..len-1 as the x-values.
        self.axes.plot(losses, marker='o', label="Loss")
        for idx, loss in enumerate(losses):
            self.axes.annotate(f"{loss:.4f}", xy=(idx, loss), rotation=45)
        self.axes.legend()
        self.axes.grid(True)