25'ten fazla konu seçemezsiniz Konular bir harf veya rakamla başlamalı, kısa çizgiler ('-') içerebilir ve en fazla 35 karakter uzunluğunda olabilir.

104 satır
4.7KB

  1. import threading
  2. import tkinter as tk
  3. from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
  4. from matplotlib.figure import Figure
  5. from neural_net.epoch import Epoch
  6. from neural_net.trainer import NeuralNetTrainer
  7. from ui.app_state import AppState
  8. from ui.front_page.plots.gradients import GradientsPlot
  9. from ui.front_page.plots.layer_weights import LayerWeightsPlot
  10. from ui.front_page.plots.loss import LossPlot
  11. from ui.front_page.sections.training_information import EpochInformation
  class TrainingPage(tk.Frame):
      """Tk page that runs a NeuralNetTrainer and displays its progress.

      The page contains an epoch-information panel, Pause/Resume and Stop
      buttons, and a notebook of plots (loss, gradients, weights of layers 0
      and 1). Training runs on a background ``threading.Thread``; the page
      registers :meth:`update_training_data` as the trainer's
      ``on_epoch_callback``.
      """

      # Gradient and weight plots are refreshed only every N epochs
      # (presumably because they are more expensive to redraw).
      HEAVY_PLOT_EVERY_N_EPOCHS = 5

      def __init__(self, parent, app_state: "AppState", on_training_finished=None):
          """Build the page; no training starts until :meth:`start` is called.

          Args:
              parent: Tk parent widget.
              app_state: Shared application state; ``neural_net`` and
                  ``model_data`` are read from it.
              on_training_finished: Optional callback forwarded to every trainer
                  as ``on_finished_callback``.
          """
          super().__init__(parent)
          self.app_state = app_state
          self.on_training_finished = on_training_finished
          self.trainer: "NeuralNetTrainer | None" = None
          self.thread: "threading.Thread | None" = None
          self.create_ui()

      def start(self, learning_rate, nr_epochs, batch_size, callback=None):
          """Stop any previous trainer and start a new one on a worker thread.

          Args:
              learning_rate: Forwarded to NeuralNetTrainer.
              nr_epochs: Forwarded to NeuralNetTrainer.
              batch_size: Forwarded to NeuralNetTrainer.
              callback: Optional no-argument callable invoked right after the
                  worker thread has been started (not when training ends).
          """
          if self.trainer is not None:
              # NOTE(review): the previous worker thread is not joined, so it may
              # still deliver callbacks briefly — confirm trainer.stop() semantics.
              self.trainer.stop()
          self.trainer = NeuralNetTrainer(
              self.app_state.neural_net,
              self.app_state.model_data,
              learning_rate=learning_rate,
              nr_epochs=nr_epochs,
              batch_size=batch_size,
              on_epoch_callback=self.update_training_data,
              on_finished_callback=self.on_training_finished,
          )
          # Duplicates the constructor argument above; kept in case the trainer
          # does not store on_epoch_callback itself.
          self.trainer.on_epoch_callback = self.update_training_data
          # A previous run may have left the button reading "Resume".
          self._refresh_pause_button()
          self.thread = threading.Thread(target=self.trainer.start)
          self.thread.start()
          if callback is not None:
              callback()

      def update_training_data(self, training_run, data: "Epoch"):
          """Push one epoch's results to the information panel and the plots.

          Updates are skipped when no trainer exists or when the current trainer
          reports it is not running. The gradient and weight plots are refreshed
          only when ``data.epoch`` is a multiple of ``HEAVY_PLOT_EVERY_N_EPOCHS``.
          """
          print(f"Updating training data {data.epoch}")
          # NOTE(review): this is presumably invoked from the worker thread created
          # in start(); Tk widgets are not generally safe to touch off the main
          # thread — consider marshalling the update via self.after(...). Confirm.
          trainer = self.trainer
          if trainer is None or not trainer.is_running:
              return
          self.training_information_container.update_training_data(training_run, data)
          self.loss_plot.update_training_data(training_run, data)
          if data.epoch % self.HEAVY_PLOT_EVERY_N_EPOCHS == 0:
              self.gradients_plot.update_training_data(training_run, data)
              self.layer0_weights_plot.update_training_data(training_run, data)
              self.layer1_weights_plot.update_training_data(training_run, data)

      def create_ui(self):
          """Create the information panel, the action buttons and the plot notebook."""
          # Notebook lives in tkinter.ttk; plain tkinter has no Notebook widget.
          from tkinter import ttk

          # Training center.
          # NOTE(review): self.trainer is still None when called from __init__, so
          # EpochInformation receives None — confirm it copes with that.
          self.training_information_container = EpochInformation(
              self, self.app_state.neural_net, self.trainer
          )
          self.training_information_container.pack(
              side=tk.TOP, fill=tk.X, expand=False, pady=10, padx=10, ipady=10, ipadx=10
          )

          actions_frame = tk.Frame(self)
          actions_frame.pack(
              side=tk.TOP, fill=tk.X, expand=False, pady=10, padx=10, ipady=10, ipadx=10
          )
          self.btn_toggle_pause = tk.Button(actions_frame, text="Pause", command=self.toggle_state)
          self.btn_toggle_pause.pack(side=tk.LEFT)
          # Resolve the trainer at click time: none exists yet while the UI is built.
          btn_stop = tk.Button(actions_frame, text="Stop", command=self.stop_training)
          btn_stop.pack(side=tk.LEFT)

          # Plot tabs
          plot_tab_control = ttk.Notebook(self)
          plot_tab_control.pack(
              side=tk.BOTTOM, fill=tk.BOTH, expand=True, pady=0, padx=0, ipady=0, ipadx=0
          )
          net = self.app_state.neural_net
          self.loss_plot = LossPlot(plot_tab_control, net)
          plot_tab_control.add(self.loss_plot, text="Loss Function")
          self.gradients_plot = GradientsPlot(plot_tab_control, net)
          plot_tab_control.add(self.gradients_plot, text="Gradients")
          # NOTE(review): the trailing integer pairs (11, 11) and (2, 5) are passed
          # positionally to LayerWeightsPlot; presumably a grid layout for the
          # layer's weights — confirm against LayerWeightsPlot.
          self.layer0_weights_plot = LayerWeightsPlot(plot_tab_control, net, net.layers[0], 11, 11)
          plot_tab_control.add(self.layer0_weights_plot, text="Weights layer 0")
          self.layer1_weights_plot = LayerWeightsPlot(plot_tab_control, net, net.layers[1], 2, 5)
          plot_tab_control.add(self.layer1_weights_plot, text="Weights layer 1")

      def toggle_state(self):
          """Toggle pause/resume on the current trainer and relabel the button.

          Does nothing when training has not been started yet.
          """
          if self.trainer is None:
              return
          self.trainer.toggle_state()
          self._refresh_pause_button()

      def stop_training(self):
          """Stop the current trainer, if any (bound to the Stop button)."""
          if self.trainer is not None:
              self.trainer.stop()

      def _refresh_pause_button(self):
          """Label the pause button from the current trainer's ``training_paused`` flag."""
          paused = self.trainer is not None and getattr(self.trainer, "training_paused", False)
          self.btn_toggle_pause.config(text="Resume" if paused else "Pause")

      @staticmethod
      def create_plot_figure(tab):
          """Embed a new, empty matplotlib Figure in ``tab`` and return it.

          The Tk canvas widget is packed to fill ``tab``.
          """
          figure = Figure()
          # Create a matplotlib canvas to display the plot
          canvas = FigureCanvasTkAgg(figure, tab)
          canvas.draw()
          canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
          return figure