|
- import tkinter as tk
-
- from ui.app_state import AppState
- from ui.components.label_with_refresh import LabelWithRefresh
-
-
class NeuralNetInfoWidget(tk.Frame):
    """Panel summarising the neural network held in ``app_state``.

    Shows one row per layer (``"<type> <index>"`` and
    ``"<input_dim> -> <output_dim> neurons"``), followed by an "Accuracy:" and a
    "Current Loss:" row. Each metric value is a ``LabelWithRefresh`` whose
    callback recomputes the metric on ``app_state.model_data``'s test split
    and then rebuilds the whole panel.

    The panel is rebuilt only when :meth:`update_ui` is called (including from
    the refresh callbacks); it does not observe ``app_state`` on its own.
    """

    # Grid padding shared by every cell in the panel.
    _GRID_PAD = {"padx": 10, "pady": 5}

    def __init__(self, parent, app_state: AppState):
        """Create the panel under *parent* and populate it from *app_state*."""
        super().__init__(parent)
        self.app_state = app_state
        # Handles to the metric value widgets; None while no network is shown.
        self.accuracy_label = None
        self.loss_label = None
        self.update_ui()

    def clear_ui(self):
        """Destroy every child widget of this frame."""
        for widget in self.winfo_children():
            widget.destroy()

    def update_ui(self):
        """Rebuild the panel from the current ``app_state``.

        When ``app_state.neural_net`` is None the panel is left empty and the
        metric label handles are reset to None. Otherwise one row is added per
        layer, then the accuracy and loss rows. The refresh controls are
        created with ``tk.DISABLED`` when ``app_state.model_data`` is None and
        ``tk.NORMAL`` otherwise.
        """
        self.clear_ui()
        # The old handles now refer to destroyed widgets; drop them.
        self.accuracy_label = None
        self.loss_label = None

        net = self.app_state.neural_net
        if net is None:
            # The metric rows below read attributes of the network, so there is
            # nothing meaningful to show without one.
            return

        row = 0
        for layer in net.layers:
            tk.Label(self, text=f"{layer.type} {layer.index}").grid(
                column=0, row=row, sticky="w", **self._GRID_PAD
            )
            tk.Label(
                self, text=f"{layer.input_dim} -> {layer.output_dim} neurons"
            ).grid(column=1, row=row, sticky="e", **self._GRID_PAD)
            row += 1

        # Recalculation needs test data, so only enable refresh when it exists.
        button_state = tk.NORMAL if self.app_state.model_data is not None else tk.DISABLED

        accuracy_text = "NA" if net.last_accuracy is None else f"{net.last_accuracy * 100:.2f}%"
        self.accuracy_label = self._add_metric_row(
            row, "Accuracy:", accuracy_text, self.recalculate_accuracy, button_state
        )
        row += 1

        loss_text = "NA" if net.last_loss is None else f"{net.last_loss:.4f}"
        self.loss_label = self._add_metric_row(
            row, "Current Loss:", loss_text, self.recalculate_loss, button_state
        )

    def _add_metric_row(self, row, title, value_text, callback, button_state):
        """Grid a title label and a refreshable value label on *row*.

        Returns the ``LabelWithRefresh`` holding *value_text*, wired to
        *callback* and created with ``initial_state=button_state``.
        """
        tk.Label(self, text=title).grid(column=0, row=row, sticky="w", **self._GRID_PAD)
        value_label = LabelWithRefresh(
            self, value_text, callback=callback, initial_state=button_state
        )
        value_label.grid(column=1, row=row, sticky="e", **self._GRID_PAD)
        return value_label

    def recalculate_accuracy(self):
        """Recompute the network's accuracy on the test split, then rebuild.

        Does nothing when no network or no model data is loaded. In that
        state the refresh control is either absent or created disabled, so
        this guard is only a safety net.
        """
        net = self.app_state.neural_net
        data = self.app_state.model_data
        if net is None or data is None:
            return
        net.recalculate_accuracy(data.test_inputs, data.test_labels)
        self.update_ui()

    def recalculate_loss(self):
        """Recompute the network's loss on the test split, then rebuild.

        Does nothing when no network or no model data is loaded. In that
        state the refresh control is either absent or created disabled, so
        this guard is only a safety net.
        """
        net = self.app_state.neural_net
        data = self.app_state.model_data
        if net is None or data is None:
            return
        net.recalculate_loss(data.test_inputs, data.test_labels)
        self.update_ui()
|