import threading
import tkinter as tk
from tkinter import ttk
from typing import Optional

from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

from neural_net.epoch import Epoch
from neural_net.trainer import NeuralNetTrainer
from ui.app_state import AppState
from ui.front_page.plots.gradients import GradientsPlot
from ui.front_page.plots.layer_weights import LayerWeightsPlot
from ui.front_page.plots.loss import LossPlot
from ui.front_page.sections.training_information import EpochInformation


class TrainingPage(tk.Frame):
    """Page that runs a NeuralNetTrainer on a worker thread and displays its progress.

    The page shows an epoch-information section, Pause/Resume and Stop buttons,
    and a notebook of plots: loss, gradients, and the weights of layers 0 and 1.
    """

    # The gradient and weight plots are only refreshed on epochs divisible by this.
    HEAVY_PLOT_EPOCH_INTERVAL = 5

    def __init__(self, parent, app_state: AppState, on_training_finished=None):
        """Create the page and its widgets.

        Args:
            parent: Tk parent widget.
            app_state: Shared application state; its ``neural_net`` and
                ``model_data`` are handed to every trainer created by ``start``.
            on_training_finished: Optional callback forwarded to the trainer as
                ``on_finished_callback``.
        """
        super().__init__(parent)
        self.app_state = app_state
        self.on_training_finished = on_training_finished
        # No trainer exists until start() is called.
        self.trainer: Optional[NeuralNetTrainer] = None
        self.thread: Optional[threading.Thread] = None
        self.create_ui()

    def start(self, learning_rate, nr_epochs, batch_size, callback=None):
        """Stop any current trainer and start a new one on a background thread.

        Args:
            learning_rate: Forwarded to ``NeuralNetTrainer``.
            nr_epochs: Forwarded to ``NeuralNetTrainer``.
            batch_size: Forwarded to ``NeuralNetTrainer``.
            callback: Optional zero-argument callable invoked right after the
                worker thread has been started.
        """
        if self.trainer is not None:
            self.trainer.stop()
        self.trainer = NeuralNetTrainer(
            self.app_state.neural_net,
            self.app_state.model_data,
            learning_rate=learning_rate,
            nr_epochs=nr_epochs,
            batch_size=batch_size,
            on_epoch_callback=self.update_training_data,
            # NOTE(review): presumably invoked from the worker thread below; if this
            # callback touches Tk widgets it should marshal through ``after``.
            on_finished_callback=self.on_training_finished,
        )
        # The new trainer has not been paused through this page yet, so undo a
        # "Resume" label that a paused previous run may have left behind.
        self.btn_toggle_pause.config(text="Pause")
        # Daemon thread: closing the window must not keep the process alive
        # until training completes.
        self.thread = threading.Thread(target=self.trainer.start, daemon=True)
        self.thread.start()
        if callback is not None:
            callback()

    def update_training_data(self, training_run, data: Epoch):
        """Receive one finished epoch from the trainer and schedule a UI refresh.

        This is registered as the trainer's ``on_epoch_callback``, so it runs on
        the worker thread started in ``start``. Tk widgets and matplotlib canvases
        should only be redrawn from the Tk main loop, so the redraw is deferred via
        ``after`` to ``_apply_training_data``.
        """
        print(f"Updating training data {data.epoch}")
        trainer = self.trainer
        if trainer is None or not trainer.is_running:
            return
        try:
            self.after(0, self._apply_training_data, training_run, data)
        except (tk.TclError, RuntimeError):
            # The widget was destroyed or the Tk main loop has exited;
            # there is nothing left to draw on, so drop the update.
            pass

    def _apply_training_data(self, training_run, data: Epoch):
        """Push one epoch's data into the info section and plots (Tk main thread)."""
        self.training_information_container.update_training_data(training_run, data)
        self.loss_plot.update_training_data(training_run, data)
        if data.epoch % self.HEAVY_PLOT_EPOCH_INTERVAL == 0:
            self.gradients_plot.update_training_data(training_run, data)
            self.layer0_weights_plot.update_training_data(training_run, data)
            self.layer1_weights_plot.update_training_data(training_run, data)

    def create_ui(self):
        """Build the information section, the action buttons and the plot notebook."""
        # Training information.
        # NOTE(review): create_ui runs from __init__, so self.trainer is still None
        # here and EpochInformation receives None — confirm it handles that.
        self.training_information_container = EpochInformation(self, self.app_state.neural_net, self.trainer)
        self.training_information_container.pack(side=tk.TOP, fill=tk.X, expand=False, pady=10, padx=10, ipady=10, ipadx=10)

        # Actions.
        actions_frame = tk.Frame(self)
        actions_frame.pack(side=tk.TOP, fill=tk.X, expand=False, pady=10, padx=10, ipady=10, ipadx=10)

        self.btn_toggle_pause = tk.Button(actions_frame, text="Pause", command=self.toggle_state)
        self.btn_toggle_pause.pack(side=tk.LEFT)
        # Resolve the trainer at click time: none exists yet, and start() replaces it.
        btn_stop = tk.Button(actions_frame, text="Stop", command=self.stop_training)
        btn_stop.pack(side=tk.LEFT)

        # Plot tabs (Notebook lives in tkinter.ttk, not tkinter).
        plot_tab_control = ttk.Notebook(self)
        plot_tab_control.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True, pady=0, padx=0, ipady=0, ipadx=0)

        neural_net = self.app_state.neural_net

        self.loss_plot = LossPlot(plot_tab_control, neural_net)
        plot_tab_control.add(self.loss_plot, text="Loss Function")

        self.gradients_plot = GradientsPlot(plot_tab_control, neural_net)
        plot_tab_control.add(self.gradients_plot, text="Gradients")

        self.layer0_weights_plot = LayerWeightsPlot(plot_tab_control, neural_net, neural_net.layers[0], 11, 11)
        plot_tab_control.add(self.layer0_weights_plot, text="Weights layer 0")

        self.layer1_weights_plot = LayerWeightsPlot(plot_tab_control, neural_net, neural_net.layers[1], 2, 5)
        plot_tab_control.add(self.layer1_weights_plot, text="Weights layer 1")

    def stop_training(self):
        """Stop the current trainer, if one has been started."""
        if self.trainer is not None:
            self.trainer.stop()

    def toggle_state(self):
        """Pause/resume the current trainer and relabel the button; no-op without a trainer."""
        if self.trainer is None:
            return
        self.trainer.toggle_state()
        if self.trainer.training_paused:
            self.btn_toggle_pause.config(text="Resume")
        else:
            self.btn_toggle_pause.config(text="Pause")

    @staticmethod
    def create_plot_figure(tab):
        """Create a matplotlib Figure embedded in ``tab`` (packed to fill it) and return it."""
        figure = Figure()
        # Create a matplotlib canvas to display the plot
        canvas = FigureCanvasTkAgg(figure, tab)
        canvas.draw()
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        return figure