No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o un número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

54 líneas
2.4KB

  1. import tkinter as tk
  2. from ui.app_state import AppState
  3. from ui.components.label_with_refresh import LabelWithRefresh
  4. class NeuralNetInfoWidget(tk.Frame):
  5. def __init__(self, parent, app_state: AppState):
  6. super().__init__(parent)
  7. self.app_state = app_state
  8. self.update_ui()
  9. def clear_ui(self):
  10. for widget in self.winfo_children():
  11. widget.destroy()
  12. def update_ui(self):
  13. self.clear_ui()
  14. row = 0
  15. if self.app_state.neural_net is not None:
  16. for layer in self.app_state.neural_net.layers:
  17. (tk.Label(self, text=f"{layer.type} {layer.index}")
  18. .grid(column=0, row=row, padx=10, pady=5, sticky='w'))
  19. tk.Label(self, text=f"{layer.input_dim} -> {layer.output_dim} neurons").grid(column=1, row=row, padx=10, pady=5, sticky='e')
  20. row += 1
  21. button_state = tk.DISABLED
  22. if self.app_state.model_data is not None:
  23. button_state = tk.NORMAL
  24. tk.Label(self, text="Accuracy:").grid(column=0, row=row, padx=10, pady=5, sticky='w')
  25. last_accuracy = "NA"
  26. if self.app_state.neural_net.last_accuracy is not None:
  27. last_accuracy = f"{self.app_state.neural_net.last_accuracy * 100:.2f}%"
  28. self.accuracy_label = LabelWithRefresh(self, last_accuracy, callback=self.recalculate_accuracy, initial_state=button_state)
  29. self.accuracy_label.grid(column=1, row=row, padx=10, pady=5, sticky='e')
  30. row += 1
  31. tk.Label(self, text="Current Loss:").grid(column=0, row=row, padx=10, pady=5, sticky='w')
  32. last_loss = "NA"
  33. if self.app_state.neural_net.last_loss is not None:
  34. last_loss = f"{self.app_state.neural_net.last_loss:.4f}"
  35. self.loss_label = LabelWithRefresh(self, last_loss, callback=self.recalculate_loss, initial_state=button_state)
  36. self.loss_label.grid(column=1, row=row, padx=10, pady=5, sticky='e')
  37. row += 1
  38. def recalculate_accuracy(self):
  39. self.app_state.neural_net.recalculate_accuracy(self.app_state.model_data.test_inputs, self.app_state.model_data.test_labels)
  40. self.update_ui()
  41. def recalculate_loss(self):
  42. self.app_state.neural_net.recalculate_loss(self.app_state.model_data.test_inputs, self.app_state.model_data.test_labels)
  43. self.update_ui()