import os
import tkinter as tk

from data.mnist_loader import MNISTModelData
from neural_net.mnist import MNISTNeuralNet
from ui.app_state import AppState
from ui.front_page.sections.neural_net_info_widget import NeuralNetInfoWidget


class NeuralNetInfo(tk.LabelFrame):
    """Labelled frame with load buttons, status labels and a model info widget.

    The frame reads and writes ``app_state.neural_net`` and
    ``app_state.model_data``. After a model or dataset has been loaded, the
    matching optional callback passed to the constructor is invoked.
    """

    # Folder holding the four MNIST idx files. The spelling is kept exactly as
    # it was because it is a filesystem path, not display text.
    DATA_FOLDER = "/projects/learning/datasets/minst"

    def __init__(self, parent, app_state: AppState, on_load_model, on_load_data):
        """Build the frame inside *parent*.

        Args:
            parent: Tk container that owns this frame.
            app_state: Shared state object; its ``neural_net`` and
                ``model_data`` attributes are read and assigned here.
            on_load_model: Callable invoked with no arguments after a model
                has been loaded, or ``None``.
            on_load_data: Callable invoked with no arguments after the data
                has been loaded, or ``None``.
        """
        super().__init__(parent, text="Model overview")
        self.app_state = app_state
        self.cb_on_load_model = on_load_model
        self.cb_on_load_data = on_load_data
        self.create_ui()

    def create_ui(self):
        """Create every child widget once and render the current state."""
        # Option to load model (could be a file dialog or dropdown in future)
        self.load_model_button = tk.Button(
            self, text="Load model", command=self.on_load_model
        )
        self.load_model_button.pack(padx=5, pady=5, side=tk.TOP)
        # The status labels always exist; update() decides whether they are
        # packed. This avoids attributes that exist only on some code paths.
        self.model_status = tk.Label(self, text="No model loaded")

        self.load_data_button = tk.Button(
            self, text="Load data", command=self.on_load_data
        )
        self.load_data_button.pack(padx=5, pady=5, side=tk.TOP)
        self.data_status = tk.Label(self, text="No data loaded")

        self.neural_net_info = None
        self.update()

    def update(self):
        """Synchronise button captions, status labels and the info widget.

        Safe to call any number of times: existing widgets are reconfigured
        and the info widget is replaced rather than duplicated.

        NOTE: this shadows ``tkinter.Misc.update`` (event processing) and does
        not call it. The name is kept because external callers may rely on it.
        """
        self._sync_section(
            self.load_model_button,
            self.model_status,
            self.app_state.neural_net is not None,
            "Load model",
            "Reload model",
        )
        self._sync_section(
            self.load_data_button,
            self.data_status,
            self.app_state.model_data is not None,
            "Load data",
            "Reload data",
        )

        # Replace the previous info widget so repeated refreshes do not stack
        # several copies in the frame. Packing without ``after`` appends it to
        # the end of the packing order, i.e. below the buttons and labels.
        if self.neural_net_info is not None:
            self.neural_net_info.destroy()
        self.neural_net_info = NeuralNetInfoWidget(self, self.app_state)
        self.neural_net_info.pack(padx=5, pady=5)

    @staticmethod
    def _sync_section(button, status_label, loaded, load_text, reload_text):
        """Set *button*'s caption and show or hide *status_label*."""
        if loaded:
            button.config(text=reload_text)
            status_label.pack_forget()  # no-op when the label is not packed
        else:
            button.config(text=load_text)
            # ``after`` pins the label directly below its button even when it
            # is (re)packed long after the rest of the layout was built.
            status_label.pack(padx=5, pady=5, after=button)

    def _recalculate_metrics(self):
        """Recompute loss and accuracy on the test split when both exist."""
        net = self.app_state.neural_net
        data = self.app_state.model_data
        if net is None or data is None:
            return
        net.recalculate_loss(data.test_inputs, data.test_labels)
        net.recalculate_accuracy(data.test_inputs, data.test_labels)

    def on_load_data(self):
        """Load the MNIST files into ``app_state`` and notify the callback.

        The widgets are not refreshed here; presumably the callback triggers
        ``update()`` - confirm against the caller.
        """
        self.app_state.model_data = MNISTModelData(
            os.path.join(self.DATA_FOLDER, "train-images-idx3-ubyte"),
            os.path.join(self.DATA_FOLDER, "train-labels-idx1-ubyte"),
            os.path.join(self.DATA_FOLDER, "t10k-images-idx3-ubyte"),
            os.path.join(self.DATA_FOLDER, "t10k-labels-idx1-ubyte"),
        )
        self._recalculate_metrics()
        if self.cb_on_load_data is not None:
            self.cb_on_load_data()

    def on_load_model(self):
        """Create a fresh ``MNISTNeuralNet`` and notify the callback.

        The widgets are not refreshed here; presumably the callback triggers
        ``update()`` - confirm against the caller.
        """
        self.app_state.neural_net = MNISTNeuralNet()
        self._recalculate_metrics()
        if self.cb_on_load_model is not None:
            self.cb_on_load_model()