import tkinter as tk

from ui.app_state import AppState
from ui.components.label_with_refresh import LabelWithRefresh


class NeuralNetInfoWidget(tk.Frame):
    """Frame summarising the neural network held in ``app_state``.

    Shows one row per layer (``type index`` and ``input_dim -> output_dim``),
    followed by an accuracy row and a loss row. Each of those two rows has a
    refresh control that re-evaluates the network on
    ``app_state.model_data.test_inputs`` / ``test_labels``.
    """

    # Grid padding shared by every cell so the rows line up.
    _PAD = {"padx": 10, "pady": 5}

    def __init__(self, parent, app_state: AppState):
        super().__init__(parent)
        self.app_state = app_state
        # Refreshable stat labels; None while no network is displayed.
        self.accuracy_label = None
        self.loss_label = None
        self.update_ui()

    def clear_ui(self):
        """Destroy every child widget of this frame."""
        for widget in self.winfo_children():
            widget.destroy()
        # The stat labels were children and are now destroyed; drop the
        # stale references instead of keeping dead widgets around.
        self.accuracy_label = None
        self.loss_label = None

    def update_ui(self):
        """Rebuild the frame from the current ``app_state``.

        Nothing is drawn when ``app_state.neural_net`` is None. The refresh
        controls are enabled only when ``app_state.model_data`` is not None,
        because the recalculation callbacks read its test inputs/labels.
        """
        self.clear_ui()
        net = self.app_state.neural_net
        if net is None:
            # Every row below reads attributes of the network; guard them all.
            return

        row = 0
        for layer in net.layers:
            tk.Label(self, text=f"{layer.type} {layer.index}").grid(
                column=0, row=row, sticky='w', **self._PAD)
            tk.Label(self, text=f"{layer.input_dim} -> {layer.output_dim} neurons").grid(
                column=1, row=row, sticky='e', **self._PAD)
            row += 1

        button_state = tk.NORMAL if self.app_state.model_data is not None else tk.DISABLED

        last_accuracy = "NA"
        if net.last_accuracy is not None:
            last_accuracy = f"{net.last_accuracy * 100:.2f}%"
        self.accuracy_label = self._add_stat_row(
            row, "Accuracy:", last_accuracy, self.recalculate_accuracy, button_state)
        row += 1

        last_loss = "NA"
        if net.last_loss is not None:
            last_loss = f"{net.last_loss:.4f}"
        self.loss_label = self._add_stat_row(
            row, "Current Loss:", last_loss, self.recalculate_loss, button_state)

    def _add_stat_row(self, row, caption, value, callback, state):
        """Grid a caption label plus a LabelWithRefresh on ``row``; return the latter."""
        tk.Label(self, text=caption).grid(column=0, row=row, sticky='w', **self._PAD)
        label = LabelWithRefresh(self, value, callback=callback, initial_state=state)
        label.grid(column=1, row=row, sticky='e', **self._PAD)
        return label

    def recalculate_accuracy(self):
        """Re-evaluate accuracy on the model data's test inputs/labels, then redraw."""
        net = self.app_state.neural_net
        data = self.app_state.model_data
        if net is not None and data is not None:
            net.recalculate_accuracy(data.test_inputs, data.test_labels)
        # Redraw even when nothing was recalculated so the UI reflects the
        # current state (e.g. the network or data was unloaded since drawing).
        self.update_ui()

    def recalculate_loss(self):
        """Re-evaluate loss on the model data's test inputs/labels, then redraw."""
        net = self.app_state.neural_net
        data = self.app_state.model_data
        if net is not None and data is not None:
            net.recalculate_loss(data.test_inputs, data.test_labels)
        # Redraw even when nothing was recalculated so the UI reflects the
        # current state (e.g. the network or data was unloaded since drawing).
        self.update_ui()