Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

32 řádky
1.1KB

  1. import struct
  2. import numpy as np
  3. from neural_net.neural_net import ModelData
  4. class MNISTModelData(ModelData):
  5. def __init__(self, fn_train_inputs, fn_train_targets, fn_test_inputs, fn_test_targets):
  6. super().__init__(
  7. self._get_images_from_idx(fn_train_inputs),
  8. self._get_labels_from_idx(fn_train_targets),
  9. self._get_images_from_idx(fn_test_inputs),
  10. self._get_labels_from_idx(fn_test_targets)
  11. )
  12. print(np.array(self.test_inputs[0]).reshape((28, 28)))
  13. def _get_images_from_idx(self, file):
  14. with open(file, 'rb') as f:
  15. magic, size = struct.unpack(">II", f.read(8))
  16. nrows, ncols = struct.unpack(">II", f.read(8))
  17. data = np.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>'))
  18. data = data.reshape((size, nrows * ncols)) / 255
  19. return 1 - data
  20. def _get_labels_from_idx(self, file):
  21. with open(file, 'rb') as f:
  22. magic, size = struct.unpack(">II", f.read(8))
  23. data = np.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>'))
  24. return data