|
- import os
- import tkinter as tk
-
- from data.mnist_loader import MNISTModelData
- from ui.app_state import AppState
- from ui.front_page.sections.model_overview_section import NeuralNetInfo
- from ui.front_page.sections.test_model_section import TestModelSection
- from ui.front_page.sections.training_section import TrainingSection
-
class FrontPage(tk.Frame):
    """Landing page of the MNIST Learning Center UI.

    Shows a welcome banner and a ``NeuralNetInfo`` section at the top. Once
    ``app_state`` holds both a neural net and model data, a second row is
    added containing a ``TrainingSection`` (left) and a ``TestModelSection``
    (right).
    """

    # Default directory used by ``load_training_data``.
    # NOTE(review): "minst" looks like a typo for "mnist"; kept verbatim
    # because it is a filesystem path — confirm against the real folder name.
    DEFAULT_DATA_FOLDER = "/projects/learning/datasets/minst"

    def __init__(self, parent, app_state: AppState):
        """Create the page and build its widgets.

        Args:
            parent: Tk container widget this frame is placed in.
            app_state: Shared application state. This page reads its
                ``neural_net`` and ``model_data`` attributes and may assign
                ``model_data``.
        """
        super().__init__(parent)
        self.parent = parent
        self.app_state = app_state

        # Child widgets; filled in by create_ui() / load_model_actions_frame().
        self.main_frame = None
        self.neural_net_info = None
        self.model_actions_frame = None
        # NOTE(review): never assigned anywhere in this class; kept only so
        # any external code reading the attribute keeps working.
        self.start_training_section = None
        self.training_section = None
        self.test_model_section = None
        self.create_ui()

    def create_ui(self):
        """Build the banner, the main frame and the model-overview section."""
        (tk.Label(self, text="Welcome to MNIST Learning Center", font=("Arial", 16))
         .pack(side=tk.TOP, fill=tk.BOTH, expand=False, padx=5))

        self.main_frame = tk.Frame(self)
        self.main_frame.pack(fill=tk.BOTH, expand=True)

        self.neural_net_info = NeuralNetInfo(self.main_frame, self.app_state, self.on_model_loaded, self.on_data_loaded)
        self.neural_net_info.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5)
        # The action row may already be buildable if app_state was pre-populated.
        self.load_model_actions_frame()

    def update(self):
        """Refresh every section of the page from the current ``app_state``.

        NOTE: this overrides ``tkinter.Misc.update`` (which processes pending
        Tcl events), so calling ``update()`` on this frame does NOT flush the
        Tk event queue. The name is kept because the callbacks below rely on it.
        """
        if self.neural_net_info is not None:
            # NOTE(review): assumes NeuralNetInfo defines its own update();
            # otherwise this falls through to tkinter's Misc.update.
            self.neural_net_info.update()
        self.load_model_actions_frame()

    def load_model_actions_frame(self):
        """Create the training/testing row once a model and data exist.

        If the row has not been built yet and ``app_state`` now holds both a
        neural net and model data, the row is built. Otherwise any sections
        that already exist are refreshed.
        """
        # Short-circuit order matches the original: app_state is only
        # consulted while the action row has not been built yet.
        if self.model_actions_frame is None and self._has_model_and_data():
            self._build_model_actions_frame()
        else:
            self._refresh_action_sections()

    def _has_model_and_data(self):
        """Return True when app_state holds both a neural net and model data."""
        return self.app_state.neural_net is not None and self.app_state.model_data is not None

    def _build_model_actions_frame(self):
        """Create and pack the action row with its training and test sections."""
        self.model_actions_frame = tk.Frame(self.main_frame)
        self.model_actions_frame.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True, padx=5)

        self.training_section = TrainingSection(self.model_actions_frame, self.app_state, self.after_training)
        self.training_section.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5)

        self.test_model_section = TestModelSection(self.model_actions_frame, self.app_state)
        self.test_model_section.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=5)

    def _refresh_action_sections(self):
        """Call update() on whichever action sections already exist."""
        if self.test_model_section is not None:
            self.test_model_section.update()
        if self.training_section is not None:
            self.training_section.update()

    def on_data_loaded(self):
        """Callback passed to NeuralNetInfo; refreshes the page after data loads."""
        print("Data loaded")
        self.update()

    def on_model_loaded(self):
        """Callback passed to NeuralNetInfo; refreshes the page after a model loads."""
        print("Model loaded")
        self.update()

    def load_training_data(self, data_folder=None):
        """Load the MNIST train/test IDX files into ``app_state.model_data``.

        Args:
            data_folder: Directory containing ``train-images-idx3-ubyte``,
                ``train-labels-idx1-ubyte``, ``t10k-images-idx3-ubyte`` and
                ``t10k-labels-idx1-ubyte``. Defaults to ``DEFAULT_DATA_FOLDER``.
        """
        if data_folder is None:
            data_folder = self.DEFAULT_DATA_FOLDER
        self.app_state.model_data = MNISTModelData(
            os.path.join(data_folder, "train-images-idx3-ubyte"),
            os.path.join(data_folder, "train-labels-idx1-ubyte"),
            os.path.join(data_folder, "t10k-images-idx3-ubyte"),
            os.path.join(data_folder, "t10k-labels-idx1-ubyte"),
        )
        self.update()

    def after_training(self):
        """Callback passed to TrainingSection; refreshes the page after training."""
        self.update()
|