|
- import struct
-
- import numpy as np
-
- from neural_net.neural_net import ModelData
-
class MNISTModelData(ModelData):
    """MNIST train/test data loaded from uncompressed IDX files.

    Images are flattened to rows of ``nrows * ncols`` float64 values computed
    as ``1 - byte / 255``, i.e. scaled to [0, 1] and inverted (a 0 byte in the
    file becomes 1.0, a 255 byte becomes 0.0). Labels are returned as the raw
    ``uint8`` values stored in the file.
    """

    # IDX magic numbers: two zero bytes, data-type code 0x08 (unsigned byte),
    # then the number of dimensions (3 for image files, 1 for label files).
    _IMAGES_MAGIC = 0x00000803
    _LABELS_MAGIC = 0x00000801

    def __init__(self, fn_train_inputs, fn_train_targets, fn_test_inputs, fn_test_targets):
        """Load the four IDX files and pass them positionally to ``ModelData``.

        The arrays are passed in the order: train images, train labels,
        test images, test labels.

        Args:
            fn_train_inputs: path to the training-image IDX file.
            fn_train_targets: path to the training-label IDX file.
            fn_test_inputs: path to the test-image IDX file.
            fn_test_targets: path to the test-label IDX file.

        Raises:
            ValueError: if any file is not a well-formed IDX file of the
                expected kind (see ``_read_idx``).
        """
        super().__init__(
            self._get_images_from_idx(fn_train_inputs),
            self._get_labels_from_idx(fn_train_targets),
            self._get_images_from_idx(fn_test_inputs),
            self._get_labels_from_idx(fn_test_targets),
        )

    def _read_idx(self, file, expected_magic, n_dims):
        """Read an unsigned-byte IDX file and return ``(dims, payload)``.

        Args:
            file: path of the IDX file to open in binary mode.
            expected_magic: magic number the header must start with.
            n_dims: number of big-endian uint32 dimension sizes that follow
                the magic number in the header.

        Returns:
            ``dims``, a tuple of the ``n_dims`` dimension sizes, and
            ``payload``, a 1-D ``uint8`` array of every byte after the header.

        Raises:
            ValueError: if the header is truncated, the magic number does not
                match ``expected_magic``, or the payload length differs from
                the product of ``dims``.
        """
        header_size = 4 * (1 + n_dims)
        with open(file, 'rb') as f:
            header = f.read(header_size)
            if len(header) != header_size:
                raise ValueError(
                    f"{file}: truncated IDX header "
                    f"({len(header)} of {header_size} bytes)"
                )
            magic, *dims = struct.unpack(f">{1 + n_dims}I", header)
            if magic != expected_magic:
                raise ValueError(
                    f"{file}: bad IDX magic number {magic:#010x}, "
                    f"expected {expected_magic:#010x}"
                )
            # Byte order is irrelevant for single-byte values, so plain uint8.
            payload = np.fromfile(f, dtype=np.uint8)

        expected_len = 1
        for d in dims:
            expected_len *= d
        if payload.size != expected_len:
            raise ValueError(
                f"{file}: IDX payload has {payload.size} bytes, "
                f"header declares {expected_len}"
            )
        return tuple(dims), payload

    def _get_images_from_idx(self, file):
        """Return the images in ``file`` as a ``(size, nrows * ncols)`` array.

        Each value is ``1 - byte / 255`` (float64), so intensities are scaled
        to [0, 1] and inverted.

        Raises:
            ValueError: if ``file`` is not a well-formed 3-D IDX image file.
        """
        (size, nrows, ncols), pixels = self._read_idx(file, self._IMAGES_MAGIC, 3)
        data = pixels.reshape((size, nrows * ncols)) / 255
        return 1 - data

    def _get_labels_from_idx(self, file):
        """Return the labels in ``file`` as a 1-D ``uint8`` array.

        Raises:
            ValueError: if ``file`` is not a well-formed 1-D IDX label file.
        """
        _, labels = self._read_idx(file, self._LABELS_MAGIC, 1)
        return labels
|