|
- import os
- import tkinter as tk
-
- from data.mnist_loader import MNISTModelData
- from neural_net.mnist import MNISTNeuralNet
- from ui.app_state import AppState
- from ui.front_page.sections.neural_net_info_widget import NeuralNetInfoWidget
-
class NeuralNetInfo(tk.LabelFrame):
    """Panel for loading the MNIST model and dataset and showing their status.

    The panel contains:
    - a "Load model" button, with a "No model loaded" label below it that is
      visible only while ``app_state.neural_net`` is ``None``;
    - a "Load data" button, with a "No data loaded" label below it that is
      visible only while ``app_state.model_data`` is ``None``;
    - a ``NeuralNetInfoWidget`` built from ``app_state``.

    The widgets are created once. ``refresh()`` (or ``update()``) syncs them
    with ``app_state`` in place instead of stacking new widgets.

    Args:
        parent: Tk parent widget.
        app_state: Shared application state. This panel reads and assigns its
            ``neural_net`` and ``model_data`` attributes.
        on_load_model: Zero-argument callable invoked after a model is loaded,
            or ``None``.
        on_load_data: Zero-argument callable invoked after the dataset is
            loaded, or ``None``.
        data_folder: Directory holding the four MNIST idx files. Defaults to
            ``DEFAULT_DATA_FOLDER``, the location the panel always used before.
    """

    # Directory the dataset was historically loaded from; used as the default.
    DEFAULT_DATA_FOLDER = "/projects/learning/datasets/minst"

    def __init__(self, parent, app_state: AppState, on_load_model, on_load_data,
                 data_folder=DEFAULT_DATA_FOLDER):
        super().__init__(parent, text="Model overview")
        self.app_state = app_state
        self.cb_on_load_model = on_load_model
        self.cb_on_load_data = on_load_data
        self.data_folder = data_folder
        self.create_ui()

    def create_ui(self):
        """Create the panel's widgets once, then sync them with ``app_state``."""
        # Option to load model (could be a file dialog or dropdown in future)
        self.load_model_button = tk.Button(self, text="Load model", command=self.on_load_model)
        self.load_model_button.pack(padx=5, pady=5, side=tk.TOP)
        # Created unconditionally so refresh() can show/hide it later.
        self.model_status = tk.Label(self, text="No model loaded")

        self.load_data_button = tk.Button(self, text="Load data", command=self.on_load_data)
        self.load_data_button.pack(padx=5, pady=5, side=tk.TOP)
        self.data_status = tk.Label(self, text="No data loaded")

        self.neural_net_info = None
        self.refresh()

    def refresh(self):
        """Bring button texts, status labels and the info widget in line with ``app_state``."""
        model_loaded = self.app_state.neural_net is not None
        data_loaded = self.app_state.model_data is not None

        self.load_model_button.config(text="Reload model" if model_loaded else "Load model")
        self._set_visible(self.model_status, not model_loaded, self.load_model_button)

        self.load_data_button.config(text="Reload data" if data_loaded else "Load data")
        self._set_visible(self.data_status, not data_loaded, self.load_data_button)

        # The info widget is rebuilt rather than mutated, matching the previous
        # behaviour of creating a fresh one per update. Presumably it reads
        # app_state only in its constructor -- TODO confirm; if it gains an
        # in-place refresh, call that instead. The old instance is destroyed
        # so widgets do not pile up.
        if self.neural_net_info is not None:
            self.neural_net_info.destroy()
        self.neural_net_info = NeuralNetInfoWidget(self, self.app_state)
        self.neural_net_info.pack(padx=5, pady=5)

    def update(self):
        """Refresh the panel from ``app_state`` (alias of ``refresh``).

        NOTE(review): this shadows ``tkinter.Misc.update``, which normally
        processes pending Tk events; this override does not. It is kept for
        existing callers. New code should call ``refresh()``.
        """
        self.refresh()

    @staticmethod
    def _set_visible(label, visible, anchor):
        """Pack *label* directly after *anchor* when *visible*, otherwise unpack it."""
        if visible:
            # ``after=`` keeps the label under its button even when re-packed.
            label.pack(padx=5, pady=5, after=anchor)
        else:
            label.pack_forget()

    def _recalculate_metrics(self):
        """Recompute the model's test-set loss and accuracy if model and data are both loaded."""
        net = self.app_state.neural_net
        data = self.app_state.model_data
        if net is None or data is None:
            return
        net.recalculate_loss(data.test_inputs, data.test_labels)
        net.recalculate_accuracy(data.test_inputs, data.test_labels)

    def on_load_data(self):
        """Load the MNIST idx files from ``self.data_folder`` into ``app_state.model_data``.

        If a model is already loaded, its test loss and accuracy are
        recomputed. The ``on_load_data`` callback, if any, is then invoked.
        """
        self.app_state.model_data = MNISTModelData(
            os.path.join(self.data_folder, "train-images-idx3-ubyte"),
            os.path.join(self.data_folder, "train-labels-idx1-ubyte"),
            os.path.join(self.data_folder, "t10k-images-idx3-ubyte"),
            os.path.join(self.data_folder, "t10k-labels-idx1-ubyte"),
        )
        self._recalculate_metrics()
        if self.cb_on_load_data is not None:
            self.cb_on_load_data()

    def on_load_model(self):
        """Create a new ``MNISTNeuralNet`` and store it in ``app_state.neural_net``.

        If data is already loaded, the new model's test loss and accuracy are
        computed. The ``on_load_model`` callback, if any, is then invoked.
        """
        self.app_state.neural_net = MNISTNeuralNet()
        self._recalculate_metrics()
        if self.cb_on_load_model is not None:
            self.cb_on_load_model()
|