import os
import tkinter as tk

from data.mnist_loader import MNISTModelData
from ui.app_state import AppState
from ui.front_page.sections.model_overview_section import NeuralNetInfo
from ui.front_page.sections.test_model_section import TestModelSection
from ui.front_page.sections.training_section import TrainingSection


class FrontPage(tk.Frame):
    """Landing page of the MNIST Learning Center.

    Shows a title and the model/data overview panel (``NeuralNetInfo``). Once
    both ``app_state.neural_net`` and ``app_state.model_data`` are set, it also
    shows a bottom panel with the training and test sections side by side.
    """

    # Default directory of the four MNIST IDX files read by load_training_data.
    # NOTE(review): "minst" looks like a transposition of "mnist"; left as-is
    # because it must match the folder on disk — confirm.
    DEFAULT_DATA_FOLDER = "/projects/learning/datasets/minst"

    def __init__(self, parent, app_state: AppState):
        """Build the page inside *parent*, driven by the shared *app_state*."""
        super().__init__(parent)
        self.parent = parent
        self.app_state = app_state

        # Child widgets; populated by create_ui() / load_model_actions_frame().
        self.main_frame = None
        self.neural_net_info = None
        self.model_actions_frame = None
        # Not referenced anywhere else in this class; kept so that external
        # code reading the attribute keeps working.
        self.start_training_section = None
        self.training_section = None
        self.test_model_section = None

        self.create_ui()

    def create_ui(self):
        """Create the title, the main container and the overview panel."""
        (tk.Label(self, text="Welcome to MNIST Learning Center", font=("Arial", 16))
         .pack(side=tk.TOP, fill=tk.BOTH, expand=False, padx=5))

        self.main_frame = tk.Frame(self)
        self.main_frame.pack(fill=tk.BOTH, expand=True)

        # The overview panel reports model / data loads back through these callbacks.
        self.neural_net_info = NeuralNetInfo(self.main_frame, self.app_state,
                                             self.on_model_loaded, self.on_data_loaded)
        self.neural_net_info.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5)

        # Only builds the actions panel if a model and data are already present.
        self.load_model_actions_frame()

    def update(self):
        """Refresh the overview panel and the model-actions panel.

        NOTE: this overrides tkinter's ``Misc.update`` (which processes pending
        Tk events) and does not call it, so ``update()`` on this widget only
        refreshes the page contents.
        """
        if self.neural_net_info is not None:
            self.neural_net_info.update()
        self.load_model_actions_frame()

    def load_model_actions_frame(self):
        """Build the training/test panel on first opportunity, otherwise refresh it.

        The panel is created the first time this is called while both
        ``app_state.neural_net`` and ``app_state.model_data`` are set. On any
        other call, whichever sections already exist are asked to update.
        """
        if self.model_actions_frame is None and self._model_and_data_ready():
            self._build_model_actions_frame()
            return

        if self.test_model_section is not None:
            self.test_model_section.update()
        if self.training_section is not None:
            self.training_section.update()

    def _model_and_data_ready(self):
        """Return True when both a network and a dataset are on app_state."""
        return (self.app_state.neural_net is not None
                and self.app_state.model_data is not None)

    def _build_model_actions_frame(self):
        """Create the bottom panel holding the training (left) and test (right) sections."""
        self.model_actions_frame = tk.Frame(self.main_frame)
        self.model_actions_frame.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True, padx=5)

        self.training_section = TrainingSection(self.model_actions_frame, self.app_state,
                                                self.after_training)
        self.training_section.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5)

        self.test_model_section = TestModelSection(self.model_actions_frame, self.app_state)
        self.test_model_section.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=5)

    def on_data_loaded(self):
        """Callback from the overview panel after a dataset has been loaded."""
        print("Data loaded")
        self.update()

    def on_model_loaded(self):
        """Callback from the overview panel after a model has been loaded."""
        print("Model loaded")
        self.update()

    def load_training_data(self, data_folder=None):
        """Load the MNIST train/test IDX files into ``app_state.model_data`` and refresh.

        Args:
            data_folder: Directory containing ``train-images-idx3-ubyte``,
                ``train-labels-idx1-ubyte``, ``t10k-images-idx3-ubyte`` and
                ``t10k-labels-idx1-ubyte``. Defaults to ``DEFAULT_DATA_FOLDER``.
        """
        if data_folder is None:
            data_folder = self.DEFAULT_DATA_FOLDER
        self.app_state.model_data = MNISTModelData(
            os.path.join(data_folder, "train-images-idx3-ubyte"),
            os.path.join(data_folder, "train-labels-idx1-ubyte"),
            os.path.join(data_folder, "t10k-images-idx3-ubyte"),
            os.path.join(data_folder, "t10k-labels-idx1-ubyte"),
        )
        self.update()

    def after_training(self):
        """Callback from the training section once a training run finishes."""
        self.update()