import logging
import queue
import threading
import tkinter as tk

from neural_net.trainer import NeuralNetTrainer
from ui.app_state import AppState
from ui.components.number_slider import NumberSlider
from ui.training_page.training_page import EpochInformation

logger = logging.getLogger(__name__)


class TrainingSection(tk.LabelFrame):
    """Panel with training hyper-parameter sliders and start/stop controls.

    Training runs on a background thread (``self.thread``). Tkinter widgets
    may only be touched from the Tk main thread, so the worker never calls
    into the UI directly: ``on_epoch_finish`` pushes a notification onto a
    queue, and a periodic ``after`` callback on the main thread drains it and
    refreshes the widgets.
    """

    # How often (in ms) the Tk main thread checks for worker notifications.
    POLL_INTERVAL_MS = 100

    def __init__(self, parent, app_state: AppState, on_update_neural_net_info):
        """Build the panel.

        Args:
            parent: Tk container this frame is placed in.
            app_state: Supplies ``neural_net`` and ``model_data`` to the trainer.
            on_update_neural_net_info: Passed straight to ``trainer.start``.
                NOTE(review): it is invoked from the training thread, so the
                callee must not touch Tk widgets directly — confirm.
        """
        super().__init__(parent, text="Model training")
        self.app_state = app_state
        self.on_update_neural_net_info = on_update_neural_net_info

        self.batch_size = tk.IntVar()
        self.batch_size.set(1000)
        self.batch_size_slider = None

        self.learning_rate = tk.DoubleVar()
        self.learning_rate.set(0.0001)
        self.learning_rate_slider = None

        self.btn_prev_epoch = None
        self.btn_start_stop = None
        self.btn_next_epoch = None
        self.stop_button = None
        self.training_information_container: EpochInformation = None

        # Background training thread; None until training is first started.
        self.thread: threading.Thread = None
        # Worker -> main-thread notifications (one item per finished epoch).
        self._events = queue.Queue()
        # Id of the pending `after` poll, so it can be cancelled on destroy.
        self._poll_job = None

        # TODO(review): the trainer receives the slider values only once, here;
        # later slider changes are not forwarded to it. Confirm whether
        # NeuralNetTrainer exposes a way to update them before each start.
        self.trainer: NeuralNetTrainer = NeuralNetTrainer(
            self.app_state.neural_net,
            self.app_state.model_data,
            self.learning_rate.get(),
            self.batch_size.get(),
        )

        self.create_ui()
        self._schedule_poll()

    def create_ui(self):
        """Create and grid the sliders and the epoch navigation/start-stop buttons."""
        tk.Label(self, text="Batch size:").grid(column=0, row=0, padx=10, pady=5, sticky='w')
        self.batch_size_slider = NumberSlider(self, self.batch_size, from_=100, to=10000, resolution=1)
        self.batch_size_slider.grid(column=1, row=0, padx=10, pady=5, sticky='w')

        tk.Label(self, text="Learning rate:").grid(column=0, row=1, padx=10, pady=5, sticky='w')
        self.learning_rate_slider = NumberSlider(self, self.learning_rate, from_=0.0001, to=0.1, resolution=0.0001)
        self.learning_rate_slider.grid(column=1, row=1, padx=10, pady=5, sticky='w')

        self.btn_prev_epoch = tk.Button(self, text="<<", command=self.on_prev_epoch)
        self.btn_prev_epoch.grid(column=0, row=2, padx=10, pady=10, sticky='w')

        self.btn_start_stop = tk.Button(self, text="Start", command=self.toggle_state)
        self.btn_start_stop.grid(column=1, row=2, padx=10, pady=10, sticky='w')

        self.btn_next_epoch = tk.Button(self, text=">>", command=self.on_next_epoch)
        self.btn_next_epoch.grid(column=2, row=2, padx=10, pady=10, sticky='w')

    def update(self):
        """Sync the start/stop button and the epoch panel with the trainer.

        Must run on the Tk main thread. Note: this shadows ``tk.Misc.update``
        (Tk event processing); the name is kept because callers rely on it.
        """
        if not self.trainer.is_running:
            self.btn_start_stop.config(text="Start")
            return

        # Set on every call, not just the first, so a restart (when the epoch
        # panel already exists) also flips the label to "Stop".
        self.btn_start_stop.config(text="Stop")

        if not self.trainer.epoch_history:
            # Training has started but no epoch has finished yet.
            return
        latest_epoch = self.trainer.epoch_history[-1]

        if self.training_information_container is None:
            self.training_information_container = EpochInformation(self, latest_epoch)
            self.training_information_container.grid(column=0, row=5, padx=10, pady=10, sticky='e')
        else:
            logger.debug("Setting the epoch")
            self.training_information_container.set_epoch(latest_epoch)

    def _worker_alive(self):
        """Return True if a training thread started by this panel is still running."""
        return self.thread is not None and self.thread.is_alive()

    def toggle_state(self):
        """Stop training if it is running, otherwise start it on a background thread."""
        if self.trainer.is_running:
            self.trainer.stop()
        elif not self._worker_alive():
            # trainer.is_running is set by the worker asynchronously; checking
            # the thread too prevents a second worker on a fast double click or
            # while the previous worker is still winding down after stop().
            self.thread = threading.Thread(
                target=self.trainer.start,
                args=(self.on_epoch_finish, self.on_update_neural_net_info),
            )
            self.thread.start()
        self.update()

    def start(self):
        """Start training on a background thread without UI callbacks.

        Does nothing if a training thread from this panel is still alive.
        """
        if self._worker_alive():
            return
        self.thread = threading.Thread(target=self.trainer.start)
        self.thread.start()

    def on_epoch_finish(self, epoch):
        """Worker-thread callback: queue a UI refresh for the Tk main thread."""
        logger.debug("Updating the epoch")
        self._events.put(epoch)

    def _schedule_poll(self):
        """Arrange for `_poll_events` to run on the Tk main thread."""
        self._poll_job = self.after(self.POLL_INTERVAL_MS, self._poll_events)

    def _poll_events(self):
        """Drain worker notifications and refresh the UI when needed (main thread)."""
        try:
            epoch_finished = False
            while True:
                try:
                    self._events.get_nowait()
                except queue.Empty:
                    break
                epoch_finished = True

            # Also refresh when the button label is stale, e.g. the worker has
            # just set is_running or training finished on its own.
            expected_label = "Stop" if self.trainer.is_running else "Start"
            if epoch_finished or self.btn_start_stop.cget("text") != expected_label:
                self.update()
        finally:
            self._schedule_poll()

    def destroy(self):
        """Cancel the pending poll callback, then destroy the widget."""
        job = getattr(self, "_poll_job", None)
        if job is not None:
            self.after_cancel(job)
            self._poll_job = None
        super().destroy()

    def on_prev_epoch(self):
        """Placeholder for showing the previous epoch; not implemented yet."""
        pass

    def on_next_epoch(self):
        """Placeholder for showing the next epoch; not implemented yet."""
        pass