import struct

import numpy as np

from neural_net.neural_net import ModelData

# IDX magic numbers for unsigned-byte data: 0x00000801 marks a 1-D array
# (labels), 0x00000803 marks a 3-D array (images: count x rows x cols).
_IDX_LABELS_MAGIC = 2049
_IDX_IMAGES_MAGIC = 2051


class MNISTModelData(ModelData):
    """Model data loaded from the four MNIST IDX files.

    Images are flattened to one row of ``nrows * ncols`` pixels per sample,
    scaled by 1/255 and inverted (``1 - pixel / 255``), so a 0 byte becomes
    1.0 and a 255 byte becomes 0.0. Labels are returned as the raw uint8
    values stored in the file.
    """

    def __init__(self, fn_train_inputs, fn_train_targets, fn_test_inputs,
                 fn_test_targets):
        """Load all four IDX files and hand them to ``ModelData``.

        Args:
            fn_train_inputs: Path to the training-images IDX file.
            fn_train_targets: Path to the training-labels IDX file.
            fn_test_inputs: Path to the test-images IDX file.
            fn_test_targets: Path to the test-labels IDX file.

        Raises:
            ValueError: If any file has the wrong magic number, a truncated
                header, or a payload size that disagrees with its header.
        """
        super().__init__(
            self._get_images_from_idx(fn_train_inputs),
            self._get_labels_from_idx(fn_train_targets),
            self._get_images_from_idx(fn_test_inputs),
            self._get_labels_from_idx(fn_test_targets),
        )

    @staticmethod
    def _read_idx_header(f, expected_magic, ndims, file):
        """Read and validate an IDX header; return its dimension sizes.

        The header is a big-endian uint32 magic number followed by one
        big-endian uint32 per dimension.
        """
        nbytes = 4 * (1 + ndims)
        header = f.read(nbytes)
        if len(header) != nbytes:
            raise ValueError(
                "{!r}: truncated IDX header ({} of {} bytes)".format(
                    file, len(header), nbytes))
        magic, *dims = struct.unpack(">{}I".format(1 + ndims), header)
        if magic != expected_magic:
            # Catches swapped image/label arguments or non-IDX files, which
            # would otherwise be silently misparsed.
            raise ValueError(
                "{!r}: bad IDX magic number {} (expected {})".format(
                    file, magic, expected_magic))
        return dims

    def _get_images_from_idx(self, file):
        """Return images as an array of shape ``(size, nrows * ncols)``.

        Pixel values are scaled to [0, 1] and inverted (``1 - p / 255``).
        """
        with open(file, 'rb') as f:
            size, nrows, ncols = self._read_idx_header(
                f, _IDX_IMAGES_MAGIC, 3, file)
            # Single-byte values have no byte order, so plain uint8 is enough.
            data = np.fromfile(f, dtype=np.uint8)
        expected = size * nrows * ncols
        if data.size != expected:
            raise ValueError(
                "{!r}: expected {} pixel bytes, found {}".format(
                    file, expected, data.size))
        return 1 - data.reshape((size, nrows * ncols)) / 255

    def _get_labels_from_idx(self, file):
        """Return the labels as a 1-D uint8 array of length ``size``."""
        with open(file, 'rb') as f:
            (size,) = self._read_idx_header(f, _IDX_LABELS_MAGIC, 1, file)
            data = np.fromfile(f, dtype=np.uint8)
        if data.size != size:
            raise ValueError(
                "{!r}: expected {} labels, found {}".format(
                    file, size, data.size))
        return data