No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

80 líneas
3.5KB

  1. import threading
  2. import tkinter as tk
  3. from neural_net.trainer import NeuralNetTrainer
  4. from ui.app_state import AppState
  5. from ui.components.number_slider import NumberSlider
  6. from ui.training_page.training_page import EpochInformation
  7. class TrainingSection(tk.LabelFrame):
  8. def __init__(self, parent, app_state: AppState, on_update_neural_net_info):
  9. super().__init__(parent, text="Model training")
  10. self.app_state = app_state
  11. self.on_update_neural_net_info = on_update_neural_net_info
  12. self.batch_size = tk.IntVar()
  13. self.batch_size.set(1000)
  14. self.batch_size_slider = None
  15. self.learning_rate = tk.DoubleVar()
  16. self.learning_rate.set(0.0001)
  17. self.learning_rate_slider = None
  18. self.btn_start_stop = None
  19. self.stop_button = None
  20. self.training_information_container: EpochInformation = None
  21. self.trainer: NeuralNetTrainer = NeuralNetTrainer(self.app_state.neural_net, self.app_state.model_data,
  22. self.learning_rate.get(), self.batch_size.get())
  23. self.create_ui()
  24. def create_ui(self):
  25. tk.Label(self, text="Batch size:").grid(column=0, row=0, padx=10, pady=5, sticky='w')
  26. self.batch_size_slider = NumberSlider(self, self.batch_size, from_=100, to=10000, resolution=1)
  27. self.batch_size_slider.grid(column=1, row=0, padx=10, pady=5, sticky='w')
  28. tk.Label(self, text="Learning rate:").grid(column=0, row=1, padx=10, pady=5, sticky='w')
  29. self.learning_rate_slider = NumberSlider(self, self.learning_rate, from_=0.0001, to=0.1, resolution=0.0001)
  30. self.learning_rate_slider.grid(column=1, row=1, padx=10, pady=5, sticky='w')
  31. self.btn_prev_epoch = tk.Button(self, text="<<", command=self.on_prev_epoch)
  32. self.btn_prev_epoch.grid(column=0, row=2, padx=10, pady=10, sticky='w')
  33. self.btn_start_stop = tk.Button(self, text="Start", command=self.toggle_state)
  34. self.btn_start_stop.grid(column=1, row=2, padx=10, pady=10, sticky='w')
  35. self.btn_next_epoch = tk.Button(self, text=">>", command=self.on_next_epoch)
  36. self.btn_next_epoch.grid(column=2, row=2, padx=10, pady=10, sticky='w')
  37. def update(self):
  38. if self.trainer.is_running:
  39. if self.training_information_container is None:
  40. self.training_information_container = EpochInformation(self, self.trainer.epoch_history[-1])
  41. self.training_information_container.grid(column=0, row=5, padx=10, pady=10, sticky='e')
  42. self.btn_start_stop.config(text="Stop")
  43. else:
  44. print("Setting the epoch")
  45. self.training_information_container.set_epoch(self.trainer.epoch_history[-1])
  46. else:
  47. self.btn_start_stop.config(text="Start")
  48. def toggle_state(self):
  49. if self.trainer.is_running:
  50. self.trainer.stop()
  51. else:
  52. self.thread = threading.Thread(target=self.trainer.start, args=(self.on_epoch_finish, self.on_update_neural_net_info))
  53. self.thread.start()
  54. # self.trainer.start(self.on_epoch_finish, self.on_update_neural_net_info)
  55. self.update()
  56. def start(self):
  57. self.thread = threading.Thread(target=self.trainer.start)
  58. self.thread.start()
  59. # self.trainer.start(on_epoch_finished=self.update_training_data)
  60. def on_epoch_finish(self, epoch):
  61. print("Updating the epoch")
  62. self.update()
  63. def on_prev_epoch(self):
  64. pass
  65. def on_next_epoch(self):
  66. pass