```python
import os
import tkinter as tk

from data.mnist_loader import MNISTModelData
from neural_net.mnist import MNISTNeuralNet
from ui.app_state import AppState
from ui.front_page.sections.neural_net_info_widget import NeuralNetInfoWidget


class NeuralNetInfo(tk.LabelFrame):
    """Front-page "Model overview" section: load buttons, status labels and model details."""

    def __init__(self, parent, app_state: AppState, on_load_model, on_load_data):
        super().__init__(parent, text="Model overview")
        self.app_state = app_state
        # Optional callbacks (may be None) invoked after a model or data load.
        self.cb_on_load_model = on_load_model
        self.cb_on_load_data = on_load_data
        self.create_ui()

    def create_ui(self):
        # Option to load model (could be a file dialog or dropdown in future)
        self.load_model_button = tk.Button(self, text="Load model", command=self.on_load_model)
        self.load_model_button.pack(padx=5, pady=5, side=tk.TOP)
        # Status labels exist only while the model or the data is missing.
        self.model_status = None
        self.data_status = None
        if self.app_state.neural_net is None:
            self.model_status = tk.Label(self, text="No model loaded")
            self.model_status.pack(padx=5, pady=5)
        else:
            self.load_model_button.config(text="Reload model")

        load_data_button = tk.Button(self, text="Load data", command=self.on_load_data)
        load_data_button.pack(padx=5, pady=5, side=tk.TOP)
        if self.app_state.model_data is None:
            self.data_status = tk.Label(self, text="No data loaded")
            self.data_status.pack(padx=5, pady=5)
        else:
            load_data_button.config(text="Reload data")

        self.neural_net_info = NeuralNetInfoWidget(self, self.app_state)
        self.neural_net_info.pack(padx=5, pady=5)

    def update(self):
        # Rebuild every child widget from the current app_state, so repeated calls
        # do not stack duplicate buttons, labels or info widgets, and stale
        # "No model loaded" / "No data loaded" labels disappear.
        # Note: this overrides tkinter's Misc.update() (event processing) on this frame.
        for child in self.winfo_children():
            child.destroy()
        self.create_ui()

    def _recalculate_metrics(self):
        # Re-evaluate loss and accuracy on the test split once both model and data exist.
        model_data = self.app_state.model_data
        if self.app_state.neural_net is None or model_data is None:
            return
        self.app_state.neural_net.recalculate_loss(model_data.test_inputs, model_data.test_labels)
        self.app_state.neural_net.recalculate_accuracy(model_data.test_inputs, model_data.test_labels)

    def on_load_data(self):
        # Hard-coded location of the four MNIST IDX files (train/test images and labels).
        data_folder = "/projects/learning/datasets/minst"
        self.app_state.model_data = MNISTModelData(
            os.path.join(data_folder, "train-images-idx3-ubyte"),
            os.path.join(data_folder, "train-labels-idx1-ubyte"),
            os.path.join(data_folder, "t10k-images-idx3-ubyte"),
            os.path.join(data_folder, "t10k-labels-idx1-ubyte"),
        )
        self._recalculate_metrics()
        if self.cb_on_load_data is not None:
            self.cb_on_load_data()

    def on_load_model(self):
        # Replace the current network with a new MNISTNeuralNet instance.
        self.app_state.neural_net = MNISTNeuralNet()
        self._recalculate_metrics()
        if self.cb_on_load_model is not None:
            self.cb_on_load_model()
```
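The branching in `create_ui` depends on only two facts: whether `app_state` holds a network and whether it holds data. One way to make that mapping testable without opening a Tk window is to compute it as plain data first and let the widget code only draw it. The sketch below illustrates that idea; `SectionView` and `describe_section` are hypothetical names, and the label texts are copied from `create_ui`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SectionView:
    """Hypothetical plain description of what the "Model overview" frame shows."""

    model_button_text: str
    model_status: Optional[str]  # None means: no status label is packed
    data_button_text: str
    data_status: Optional[str]


def describe_section(model_loaded: bool, data_loaded: bool) -> SectionView:
    # Mirrors create_ui: a status label appears only while the item is missing,
    # and the button switches to "Reload ..." once the item exists.
    return SectionView(
        model_button_text="Reload model" if model_loaded else "Load model",
        model_status=None if model_loaded else "No model loaded",
        data_button_text="Reload data" if data_loaded else "Load data",
        data_status=None if data_loaded else "No data loaded",
    )
```

With this split, the Tk code would only translate each field into a `tk.Button` or `tk.Label`, and the state logic can be checked with ordinary unit tests.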
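`on_load_data` hands `MNISTModelData` the four raw MNIST files, which use the IDX format (`idx3-ubyte` for images, `idx1-ubyte` for labels). The loader's implementation is not part of this file, so the following is only a minimal standalone sketch of reading such files with the standard library, assuming the usual IDX layout: two zero bytes, a type byte (`0x08` for unsigned bytes), a byte giving the number of dimensions, one big-endian 32-bit size per dimension, then the raw data. The names `read_idx`, `load_images` and `load_labels` are hypothetical.

```python
import math
import struct
from typing import List, Tuple


def read_idx(path: str) -> Tuple[Tuple[int, ...], bytes]:
    """Read an unsigned-byte IDX file; return (shape, raw bytes in row-major order)."""
    with open(path, "rb") as f:
        zero1, zero2, dtype, ndim = struct.unpack(">BBBB", f.read(4))
        if zero1 != 0 or zero2 != 0:
            raise ValueError("not an IDX file")
        if dtype != 0x08:
            raise ValueError(f"only unsigned-byte data (0x08) is handled, got {dtype:#04x}")
        # One big-endian unsigned 32-bit size per dimension.
        shape = struct.unpack(f">{ndim}I", f.read(4 * ndim))
        count = math.prod(shape)
        data = f.read(count)
        if len(data) != count:
            raise ValueError("file is truncated")
    return shape, data


def load_images(path: str) -> List[List[float]]:
    """Flatten each (rows x cols) image into a list of pixels scaled to [0, 1]."""
    shape, data = read_idx(path)
    if len(shape) != 3:
        raise ValueError("expected a 3-D image file (count, rows, cols)")
    n, rows, cols = shape
    size = rows * cols
    return [[b / 255.0 for b in data[i * size:(i + 1) * size]] for i in range(n)]


def load_labels(path: str) -> List[int]:
    """Return the labels as plain ints (digits 0-9 for MNIST)."""
    shape, data = read_idx(path)
    if len(shape) != 1:
        raise ValueError("expected a 1-D label file")
    return list(data)
```

Scaling pixels to [0, 1] is a common preprocessing choice, not necessarily what `MNISTModelData` does.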
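Both handlers ask the network to `recalculate_loss` and `recalculate_accuracy` on `test_inputs` and `test_labels`. `MNISTNeuralNet` is defined elsewhere, so the following is only a sketch of what such metrics commonly compute, assuming the network produces one probability vector per input (for example a softmax over the ten digits) and labels are integer class indices. Loss is taken as mean categorical cross-entropy and accuracy as the fraction of argmax predictions that match; the function names are hypothetical.

```python
import math
from typing import Sequence


def _check(probs: Sequence[Sequence[float]], labels: Sequence[int]) -> None:
    if len(probs) != len(labels) or len(labels) == 0:
        raise ValueError("need the same, non-zero number of predictions and labels")


def cross_entropy_loss(probs: Sequence[Sequence[float]], labels: Sequence[int],
                       eps: float = 1e-12) -> float:
    """Mean negative log-probability assigned to the true class."""
    _check(probs, labels)
    total = 0.0
    for p, y in zip(probs, labels):
        # Clamp to eps so a confident wrong prediction (p[y] == 0) stays finite.
        total -= math.log(max(p[y], eps))
    return total / len(labels)


def accuracy(probs: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-probability class equals the label."""
    _check(probs, labels)
    correct = sum(
        1 for p, y in zip(probs, labels)
        if max(range(len(p)), key=p.__getitem__) == y  # argmax; ties pick the first index
    )
    return correct / len(labels)
```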