Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

24 lignes
1.0KB

  1. import os.path
  2. from data.mnist_loader import MNISTModelData
  3. from neural_net.mnist import MNISTNeuralNet
  4. from neural_net.neural_net import NeuralNet, ModelData
  5. class AppState:
  6. def __init__(self, auto_load=False):
  7. self.trainers = []
  8. if auto_load:
  9. self.neural_net: NeuralNet = MNISTNeuralNet()
  10. data_folder = "/projects/learning/datasets/minst"
  11. self.model_data: ModelData = MNISTModelData(
  12. os.path.join(data_folder, "train-images-idx3-ubyte"),
  13. os.path.join(data_folder, "train-labels-idx1-ubyte"),
  14. os.path.join(data_folder, "t10k-images-idx3-ubyte"),
  15. os.path.join(data_folder, "t10k-labels-idx1-ubyte")
  16. )
  17. self.neural_net.recalculate_accuracy(self.model_data.test_inputs, self.model_data.test_labels)
  18. self.neural_net.recalculate_loss(self.model_data.test_inputs, self.model_data.test_labels)
  19. else:
  20. self.neural_net: NeuralNet = None
  21. self.model_data: ModelData = None