|
- import os.path
-
- from data.mnist_loader import MNISTModelData
- from neural_net.mnist import MNISTNeuralNet
- from neural_net.neural_net import NeuralNet, ModelData
-
class AppState:
    """Application state: the active neural network, its model data and trainers.

    With ``auto_load=True`` the constructor builds an ``MNISTNeuralNet``,
    loads the four MNIST IDX files from ``data_folder`` into an
    ``MNISTModelData`` and calls the network's ``recalculate_accuracy`` and
    ``recalculate_loss`` on the test inputs/labels. Otherwise ``neural_net``
    and ``model_data`` stay ``None`` for the caller to populate later.
    """

    # Default directory holding the MNIST IDX files. Kept byte-identical to the
    # original hard-coded path.
    # NOTE(review): "minst" may be a typo for "mnist"; not changed because the
    # actual on-disk directory name is unknown. Override via ``data_folder``.
    DEFAULT_DATA_FOLDER = "/projects/learning/datasets/minst"

    def __init__(self, auto_load=False, data_folder=DEFAULT_DATA_FOLDER):
        """Create the state, optionally loading the MNIST network and data.

        Args:
            auto_load: When true, construct the MNIST network and data and
                compute the network's initial test accuracy and loss.
            data_folder: Directory containing the MNIST IDX files. Only used
                when ``auto_load`` is true; defaults to ``DEFAULT_DATA_FOLDER``.
        """
        # Fresh list per instance; presumably filled with trainer objects by
        # other code — TODO confirm the element type.
        self.trainers = []
        # Both stay None unless auto_load populates them (or a caller does).
        self.neural_net: NeuralNet | None = None
        self.model_data: ModelData | None = None
        if auto_load:
            self._load_mnist(data_folder)

    def _load_mnist(self, data_folder):
        """Build the MNIST network and data from *data_folder* and evaluate it on the test split."""
        self.neural_net = MNISTNeuralNet()
        self.model_data = MNISTModelData(
            os.path.join(data_folder, "train-images-idx3-ubyte"),
            os.path.join(data_folder, "train-labels-idx1-ubyte"),
            os.path.join(data_folder, "t10k-images-idx3-ubyte"),
            os.path.join(data_folder, "t10k-labels-idx1-ubyte"),
        )
        test_inputs = self.model_data.test_inputs
        test_labels = self.model_data.test_labels
        # Populate the network's initial metrics from the held-out test split.
        self.neural_net.recalculate_accuracy(test_inputs, test_labels)
        self.neural_net.recalculate_loss(test_inputs, test_labels)
|