You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

79 lines
3.2KB

  1. import unittest
  2. import numpy as np
  3. from neural_net.mnist import MNISTNeuralNet
  4. from neural_net.functions.loss import cross_entropy_loss
  5. # noinspection PyMethodMayBeStatic
  6. class MNISTNeuralNetTests(unittest.TestCase):
  7. def test_loss(self):
  8. mnist = MNISTNeuralNet()
  9. # Sample predictions and labels for testing the loss function
  10. predictions = np.array([[0.1, 0.2, 0.7], # Example of a softmax output (probabilities)
  11. [0.2, 0.6, 0.2]])
  12. # Corresponding labels (correct class indices)
  13. labels = np.array([2, 1]) # Labels are class indices (not one-hot)
  14. # Expected loss (you may need to compute this manually to verify correctness)
  15. expected_loss = cross_entropy_loss(predictions, labels) # Replace with the actual expected loss value
  16. # Call the loss function
  17. computed_loss = mnist.loss(predictions, labels)
  18. # Assert that the computed loss matches the expected loss
  19. self.assertAlmostEqual(computed_loss, expected_loss, places=5, msg="Loss function is incorrect")
  20. def test_derivative_loss(self):
  21. mnist = MNISTNeuralNet()
  22. # Sample predictions and labels for testing the derivative of the loss function
  23. predictions = np.array([[0.1, 0.2, 0.7], # Example of softmax output (probabilities)
  24. [0.2, 0.6, 0.2]])
  25. # Corresponding labels (correct class indices)
  26. labels = np.array([2, 1]) # Labels are class indices
  27. # Expected derivative of loss (manually computed or from a trusted source)
  28. expected_derivative = np.array([[0.1, 0.2, -0.3], # Replace with actual expected gradient
  29. [0.2, -0.4, 0.2]])
  30. # Call the derivative loss function
  31. computed_derivative = mnist.loss_derivative(predictions, labels)
  32. # Assert that the computed derivative matches the expected derivative
  33. np.testing.assert_array_almost_equal(computed_derivative, expected_derivative, decimal=5,
  34. err_msg="Derivative of loss function is incorrect")
  35. def test_derivative_loss2(self):
  36. mnist = MNISTNeuralNet()
  37. # Given outputs
  38. outputs = np.array([
  39. [0.06873367, 0.043651, 0.043651, 0.05235898, 0.043651, 0.043651,
  40. 0.043651, 0.043651, 0.0563062, 0.043651],
  41. [0.043651, 0.043651, 0.05704588, 0.0551587, 0.05460022, 0.043651,
  42. 0.043651, 0.043651, 0.07723706, 0.05474726]
  43. ])
  44. # Labels
  45. labels = [7, 2]
  46. num_classes = 10
  47. # Convert labels to one-hot encoding
  48. labels_one_hot = np.zeros((len(labels), num_classes))
  49. for i, label in enumerate(labels):
  50. labels_one_hot[i, label] = 1
  51. # Calculate the expected loss derivative
  52. expected_loss_derivative = outputs - labels_one_hot
  53. # Call the derivative loss function
  54. computed_loss_derivative = mnist.loss_derivative(outputs, labels)
  55. # Assert that the computed derivative matches the expected derivative
  56. np.testing.assert_array_almost_equal(computed_loss_derivative, expected_loss_derivative, decimal=5,
  57. err_msg="Derivative of loss function is incorrect")