|
- import threading
- import tkinter as tk
-
- from neural_net.trainer import NeuralNetTrainer
- from ui.app_state import AppState
- from ui.components.number_slider import NumberSlider
- from ui.training_page.training_page import EpochInformation
-
-
class TrainingSection(tk.LabelFrame):
    """"Model training" panel: hyper-parameter sliders, Start/Stop and epoch buttons.

    Training runs on a background thread. Tkinter widgets must only be touched
    from the thread running the Tk main loop, so the epoch callback handed to
    the trainer only records that an epoch finished; the main loop picks that up
    by polling with ``after`` and refreshes the widgets itself.
    """

    # How often (in milliseconds) the Tk main loop checks on the training thread.
    POLL_INTERVAL_MS = 100

    def __init__(self, parent, app_state: AppState, on_update_neural_net_info):
        """Build the section and the trainer it controls.

        :param parent: Tk container this frame is placed in.
        :param app_state: supplies ``neural_net`` and ``model_data`` to the trainer.
        :param on_update_neural_net_info: forwarded unchanged to ``trainer.start``.
        """
        super().__init__(parent, text="Model training")
        self.app_state = app_state
        # NOTE(review): handed straight to trainer.start, which runs on the
        # training thread; if this callback touches Tk widgets it needs the same
        # main-loop hand-off as the epoch callback. TODO confirm.
        self.on_update_neural_net_info = on_update_neural_net_info

        self.batch_size = tk.IntVar()
        self.batch_size.set(1000)
        self.batch_size_slider = None
        self.learning_rate = tk.DoubleVar()
        self.learning_rate.set(0.0001)
        self.learning_rate_slider = None
        self.btn_prev_epoch = None
        self.btn_start_stop = None
        self.btn_next_epoch = None
        self.stop_button = None  # not used within this class
        self.training_information_container: "EpochInformation | None" = None

        # Background training thread; None until training is first started.
        self.thread: "threading.Thread | None" = None
        # Hand-off from the training thread to the Tk main loop.
        self._epoch_finished = threading.Event()
        self._last_finished_epoch = None
        # Id of the pending `after` poll, kept so it can be cancelled.
        self._poll_job = None

        # NOTE(review): the slider values are read only here, so moving the
        # sliders later does not reach this trainer. TODO push them on start once
        # the NeuralNetTrainer API for that is confirmed.
        self.trainer: NeuralNetTrainer = NeuralNetTrainer(self.app_state.neural_net, self.app_state.model_data,
                                                          self.learning_rate.get(), self.batch_size.get())
        self.create_ui()

    def create_ui(self):
        """Lay out the sliders (rows 0-1) and the <<, Start/Stop, >> buttons (row 2)."""
        tk.Label(self, text="Batch size:").grid(column=0, row=0, padx=10, pady=5, sticky='w')
        self.batch_size_slider = NumberSlider(self, self.batch_size, from_=100, to=10000, resolution=1)
        self.batch_size_slider.grid(column=1, row=0, padx=10, pady=5, sticky='w')

        tk.Label(self, text="Learning rate:").grid(column=0, row=1, padx=10, pady=5, sticky='w')
        self.learning_rate_slider = NumberSlider(self, self.learning_rate, from_=0.0001, to=0.1, resolution=0.0001)
        self.learning_rate_slider.grid(column=1, row=1, padx=10, pady=5, sticky='w')

        self.btn_prev_epoch = tk.Button(self, text="<<", command=self.on_prev_epoch)
        self.btn_prev_epoch.grid(column=0, row=2, padx=10, pady=10, sticky='w')
        self.btn_start_stop = tk.Button(self, text="Start", command=self.toggle_state)
        self.btn_start_stop.grid(column=1, row=2, padx=10, pady=10, sticky='w')
        self.btn_next_epoch = tk.Button(self, text=">>", command=self.on_next_epoch)
        self.btn_next_epoch.grid(column=2, row=2, padx=10, pady=10, sticky='w')

    def update(self):
        """Sync the Start/Stop label and the epoch panel with the trainer's state.

        Must be called on the Tk main-loop thread.
        NOTE: this shadows ``tk.Misc.update`` (which processes pending Tk events)
        for this widget; kept under this name because callers rely on it.
        """
        self._refresh_button()
        self._refresh_epoch_panel()

    def _refresh_button(self):
        """Label the button "Stop" while the trainer reports running, else "Start"."""
        self.btn_start_stop.config(text="Stop" if self.trainer.is_running else "Start")

    def _refresh_epoch_panel(self):
        """Show the last entry of ``trainer.epoch_history`` in the epoch panel.

        The panel is created the first time an entry exists while training runs.
        Once created it is refreshed even after stopping, so the last epoch of a
        run stays visible.
        """
        history = self.trainer.epoch_history
        if not history:
            # Nothing to show yet; indexing [-1] here would raise IndexError.
            return
        latest = history[-1]
        if self.training_information_container is not None:
            self.training_information_container.set_epoch(latest)
        elif self.trainer.is_running:
            self.training_information_container = EpochInformation(self, latest)
            self.training_information_container.grid(column=0, row=5, padx=10, pady=10, sticky='e')

    def toggle_state(self):
        """Start/Stop button handler: stop a running trainer, otherwise start one."""
        if self.trainer.is_running:
            self.trainer.stop()
        elif not self._training_thread_alive():
            self._start_training_thread()
        # Otherwise a previous training thread is still alive (e.g. winding down
        # after stop()); launching another now would train the same network
        # from two threads at once, so the click is ignored.
        self.update()

    def start(self):
        """Start training in the background, exactly as the Start button does.

        Does nothing while a previous training thread is still alive.
        """
        if not self._training_thread_alive():
            self._start_training_thread()

    def on_epoch_finish(self, epoch):
        """Refresh the UI after an epoch completes.

        Invoked on the Tk main-loop thread by ``_poll_training``. ``epoch`` is
        unused: ``update`` reads the latest entry from ``trainer.epoch_history``.
        """
        self.update()

    def on_prev_epoch(self):
        """Handler for the "<<" button. Not implemented yet."""
        pass

    def on_next_epoch(self):
        """Handler for the ">>" button. Not implemented yet."""
        pass

    def destroy(self):
        """Cancel the pending poll before the widgets go away.

        The training thread itself is not stopped here.
        """
        self._cancel_poll()
        super().destroy()

    def _training_thread_alive(self):
        """Return True while a training thread started by this section is running."""
        return self.thread is not None and self.thread.is_alive()

    def _start_training_thread(self):
        """Run ``trainer.start`` on a new thread and begin polling it from the main loop."""
        self._cancel_poll()
        self._epoch_finished.clear()  # drop any stale signal from a previous run
        self.thread = threading.Thread(target=self.trainer.start,
                                       args=(self._record_epoch, self.on_update_neural_net_info))
        self.thread.start()
        self._poll_job = self.after(self.POLL_INTERVAL_MS, self._poll_training)

    def _record_epoch(self, epoch):
        """Epoch callback given to the trainer; runs on the training thread, so no Tk calls."""
        self._last_finished_epoch = epoch
        self._epoch_finished.set()

    def _poll_training(self):
        """Main-loop side of the hand-off; reschedules itself while the thread is alive."""
        self._poll_job = None
        if self._epoch_finished.is_set():
            self._epoch_finished.clear()
            self.on_epoch_finish(self._last_finished_epoch)
        if self._training_thread_alive():
            # Keep the label in step with trainer.is_running between epochs.
            self._refresh_button()
            self._poll_job = self.after(self.POLL_INTERVAL_MS, self._poll_training)
        else:
            # Final refresh so the label returns to "Start" even when training
            # ended on its own rather than via the Stop button.
            self.update()

    def _cancel_poll(self):
        """Cancel the scheduled ``_poll_training`` call, if any."""
        job = getattr(self, "_poll_job", None)
        if job is not None:
            self.after_cancel(job)
            self._poll_job = None
|