|
- import threading
- import tkinter as tk
-
- from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
- from matplotlib.figure import Figure
-
- from neural_net.epoch import Epoch
- from neural_net.trainer import NeuralNetTrainer
- from ui.app_state import AppState
- from ui.front_page.plots.gradients import GradientsPlot
- from ui.front_page.plots.layer_weights import LayerWeightsPlot
- from ui.front_page.plots.loss import LossPlot
- from ui.front_page.sections.training_information import EpochInformation
-
-
class TrainingPage(tk.Frame):
    """Frame that launches a training run and displays its progress.

    The page holds an ``EpochInformation`` panel, Pause/Stop buttons and a
    tabbed set of plots (loss, gradients, weights of layers 0 and 1).
    No trainer exists until :meth:`start` is called; until then the
    Pause and Stop buttons do nothing.
    """

    def __init__(self, parent, app_state: AppState, on_training_finished=None):
        """Build the widgets; training itself only begins in :meth:`start`.

        Args:
            parent: Tk widget that contains this frame.
            app_state: Shared application state; its ``neural_net`` and
                ``model_data`` attributes are read by this page.
            on_training_finished: Optional callable handed to the trainer as
                its ``on_finished_callback``.
        """
        super().__init__(parent)
        self.app_state = app_state
        self.on_training_finished = on_training_finished
        # Both stay None until start() creates a trainer and its thread.
        self.trainer: NeuralNetTrainer | None = None
        self.thread: threading.Thread | None = None
        self.create_ui()

    def start(self, learning_rate, nr_epochs, batch_size, callback=None):
        """Stop any current trainer and run a new one on a worker thread.

        Args:
            learning_rate: Forwarded to ``NeuralNetTrainer``.
            nr_epochs: Forwarded to ``NeuralNetTrainer``.
            batch_size: Forwarded to ``NeuralNetTrainer``.
            callback: Optional zero-argument callable invoked right after the
                worker thread has been started (not when training ends).
        """
        if self.trainer is not None:
            self.trainer.stop()
        self.trainer = NeuralNetTrainer(
            self.app_state.neural_net,
            self.app_state.model_data,
            learning_rate=learning_rate,
            nr_epochs=nr_epochs,
            batch_size=batch_size,
            on_epoch_callback=self.update_training_data,
            on_finished_callback=self.on_training_finished,
        )
        # NOTE(review): the callback is already passed to the constructor
        # above, so this assignment looks redundant - kept until the
        # constructor's handling of on_epoch_callback is confirmed.
        self.trainer.on_epoch_callback = self.update_training_data
        self.thread = threading.Thread(target=self.trainer.start)
        self.thread.start()
        if callback is not None:
            callback()

    def update_training_data(self, training_run, data: Epoch):
        """Forward one epoch's results to the information panel and plots.

        NOTE(review): this is registered as the trainer's epoch callback and
        the trainer is started on the thread created in start(), so it
        presumably runs off the Tk main thread - confirm the widgets
        tolerate that.

        Args:
            training_run: Passed through unchanged to every widget.
            data: Epoch record; ``data.epoch`` selects which plots refresh.
        """
        print(f"Updating training data {data.epoch}")
        trainer = self.trainer
        if trainer is not None and trainer.is_running:
            self.training_information_container.update_training_data(training_run, data)

        self.loss_plot.update_training_data(training_run, data)
        # The remaining plots are refreshed on every fifth epoch only.
        if data.epoch % 5 == 0:
            self.gradients_plot.update_training_data(training_run, data)
            self.layer0_weights_plot.update_training_data(training_run, data)
            self.layer1_weights_plot.update_training_data(training_run, data)

    def create_ui(self):
        """Create the information panel, action buttons and plot notebook."""
        # Notebook is a ttk widget; plain tkinter has no Notebook class.
        from tkinter import ttk

        # Training center
        # NOTE(review): self.trainer is still None here because create_ui()
        # runs from __init__; confirm EpochInformation copes with that.
        self.training_information_container = EpochInformation(
            self, self.app_state.neural_net, self.trainer
        )
        self.training_information_container.pack(
            side=tk.TOP, fill=tk.X, expand=False, pady=10, padx=10, ipady=10, ipadx=10
        )

        actions_frame = tk.Frame(self)
        actions_frame.pack(
            side=tk.TOP, fill=tk.X, expand=False, pady=10, padx=10, ipady=10, ipadx=10
        )
        self.btn_toggle_pause = tk.Button(
            actions_frame, text="Pause", command=self.toggle_state
        )
        self.btn_toggle_pause.pack(side=tk.LEFT)
        # Bind to a guarded handler: the trainer does not exist yet, so
        # binding self.trainer.stop directly would fail during construction.
        btn_stop = tk.Button(actions_frame, text="Stop", command=self._stop_training)
        btn_stop.pack(side=tk.LEFT)

        # Plot tabs
        plot_tab_control = ttk.Notebook(self)
        plot_tab_control.pack(
            side=tk.BOTTOM, fill=tk.BOTH, expand=True, pady=0, padx=0, ipady=0, ipadx=0
        )

        neural_net = self.app_state.neural_net

        self.loss_plot = LossPlot(plot_tab_control, neural_net)
        plot_tab_control.add(self.loss_plot, text="Loss Function")

        self.gradients_plot = GradientsPlot(plot_tab_control, neural_net)
        plot_tab_control.add(self.gradients_plot, text="Gradients")

        # NOTE(review): the trailing integer pairs (11, 11) and (2, 5) are
        # presumably plot grid dimensions - confirm against LayerWeightsPlot.
        self.layer0_weights_plot = LayerWeightsPlot(
            plot_tab_control, neural_net, neural_net.layers[0], 11, 11
        )
        plot_tab_control.add(self.layer0_weights_plot, text="Weights layer 0")

        self.layer1_weights_plot = LayerWeightsPlot(
            plot_tab_control, neural_net, neural_net.layers[1], 2, 5
        )
        plot_tab_control.add(self.layer1_weights_plot, text="Weights layer 1")

    def toggle_state(self):
        """Toggle the trainer's state and relabel the Pause/Resume button.

        Does nothing when no trainer has been created yet.
        """
        if self.trainer is None:
            return
        self.trainer.toggle_state()
        label = "Resume" if self.trainer.training_paused else "Pause"
        self.btn_toggle_pause.config(text=label)

    def _stop_training(self):
        """Stop the current trainer, if one exists (Stop button handler)."""
        if self.trainer is not None:
            self.trainer.stop()

    @staticmethod
    def create_plot_figure(tab):
        """Create a matplotlib ``Figure`` embedded in *tab* and return it.

        Args:
            tab: Tk widget that will contain the figure's canvas.

        Returns:
            The newly created ``Figure``.
        """
        figure = Figure()

        # Create a matplotlib canvas to display the plot
        canvas = FigureCanvasTkAgg(figure, tab)
        canvas.draw()
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        return figure
|