You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
2.2KB

  1. from neural_net.epoch import Epoch
  2. from neural_net.neural_net import NeuralNet, ModelData
  3. class NeuralNetTrainer:
  4. def __init__(self, neural_net: NeuralNet, model_data: ModelData, learning_rate: float, batch_size: int):
  5. self.neural_net = neural_net
  6. self.model_data = model_data
  7. self.is_running = False
  8. self.epoch_history = []
  9. self.learning_rate = learning_rate
  10. self.batch_size = batch_size
  11. def set_learning_rate(self, learning_rate: float):
  12. self.learning_rate = learning_rate
  13. def set_batch_size(self, batch_size: int):
  14. self.batch_size = batch_size
  15. def run_epoch(self):
  16. epoch = Epoch(len(self.epoch_history),
  17. self.model_data.training_inputs,
  18. self.model_data.training_labels,
  19. self.learning_rate,
  20. self.batch_size
  21. )
  22. self._train_one_epoch(epoch)
  23. return epoch
  24. def start(self, on_epoch_finish=None, on_finish=None):
  25. self.is_running = True
  26. while True:
  27. # Stop function was called causing the trainer to reset
  28. if not self.is_running:
  29. break
  30. # Perform one epoch of training
  31. # In the future, we will apply a learning-rate algorithm
  32. epoch = self.run_epoch()
  33. if on_epoch_finish is not None:
  34. on_epoch_finish(epoch)
  35. if on_finish is not None:
  36. on_finish()
  37. self.stop()
  38. def stop(self):
  39. if self.is_running:
  40. self.is_running = False
  41. def _train_one_epoch(self, epoch: Epoch):
  42. epoch.start()
  43. for batch in epoch.batches:
  44. batch.predictions = self.neural_net.forward(batch.inputs)
  45. dL_dout = self.neural_net.loss_derivative(batch.predictions, batch.labels)
  46. layer_dl_gradients, layer_dl_biases, layer_weights, layer_biases = self.neural_net.backward(dL_dout, epoch)
  47. epoch.layer_dl_gradients.append(layer_dl_gradients)
  48. epoch.layer_dl_biases.append(layer_dl_biases)
  49. epoch.finish(self.neural_net)
  50. epoch.loss = self.neural_net.loss(epoch.all_predictions(), epoch.all_labels())
  51. self.epoch_history.append(epoch)