You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

60 lines
2.8KB

  1. import numpy as np
  2. # Your softmax outputs
  3. outputs = np.array([
  4. [
  5. [
  6. 0.90924643, 0.0, 0.26800049, 0.0, 0.14153697, 0.07644807,
  7. 0.0, 0.63928418, 0.14899383, 0.29679539, 0.29560591, 0.46324955,
  8. 0.38955634, 0.0, 0.05094845, 0.0, 0.0, 0.26734416, 0.0,
  9. 0.28399383, 0.0429699, 0.68988006, 0.0, 0.0, 0.0, 0.02901288,
  10. 0.0, 0.01076904, 0.0, 0.41230365, 0.58630857, 0.0, 0.29906131,
  11. 0.0, 0.00339327, 0.47909497, 0.07787446, 0.0, 0.0, 0.0, 0.0,
  12. 0.0, 0.0, 0.59843748, 0.18691183, 0.0, 0.0, 0.0, 0.84100045,
  13. 0.24468988, 0.0144432, 0.0, 0.27832373, 0.0, 0.45574082,
  14. 0.16037272, 0.0, 0.28562163, 0.0, 0.0, 0.44667622, 0.0, 0.0,
  15. 0.29725156, 0.0, 0.01500714, 0.51253602, 0.18559459, 0.07919077,
  16. 0.0, 0.15155614, 0.0, 0.16996095, 0.26832836, 0.0, 0.56057083,
  17. 0.47535547, 0.0, 0.08280879, 0.0, 0.07266015, 0.43079376,
  18. 0.55633086, 0.0, 0.13123258, 0.33282808, 0.0, 0.73207594, 0.0,
  19. 0.08246748, 0.0, 0.0, 0.0, 0.03605279, 0.56645505, 0.0,
  20. 0.66074054, 0.0, 0.0, 0.07871833, 0.0, 0.0, 0.0, 0.0,
  21. 0.0, 0.0, 0.26077944, 0.0, 0.0, 0.19883228, 0.26075606,
  22. 0.0, 0.55120887, 0.0, 0.0, 0.13896239, 0.8079261, 0.0
  23. ],
  24. [
  25. 1.3890246, 0.0, 0.0176582, 0.41937874, 0.01668789, 0.08115837,
  26. 0.0, 0.0, 0.0, 0.03283852, 0.0, 0.28331658, 0.0, 0.56971081,
  27. 1.29951652, 0.0, 0.05585489, 0.0, 0.0, 0.0, 0.4555721, 0.0,
  28. 0.0, 0.0, 1.13440652, 0.3462467, 0.53066361, 0.85311426,
  29. 0.13320967, 0.61478612, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04859889,
  30. 0.0, 0.0884254, 0.0, 0.56573542, 0.18211658, 0.0, 0.24407104,
  31. 0.0, 0.07133323, 0.0, 0.0, 0.98712028, 0.0, 0.06996351,
  32. 0.70575429, 0.30689567, 0.47709064, 0.07469221, 0.40548246,
  33. 0.09671662, 0.56150121, 0.0, 0.7116001, 0.57194077, 0.0,
  34. 0.10528511, 0.20317026, 0.03516737, 0.0, 0.0, 0.10198436,
  35. 0.0, 0.0, 0.0, 0.35702522, 0.0, 0.0, 0.32883485, 0.0,
  36. 0.0, 0.18996724, 0.0, 0.0, 0.0, 0.06601356, 0.0,
  37. 0.41925782, 0.0, 0.0, 0.07929863, 0.28089351, 0.0,
  38. 0.25405591, 0.09954264, 1.05735563, 0.0, 0.57732162, 0.0,
  39. 0.05791431, 0.0, 0.42524903, 0.0, 0.0, 0.0, 0.0, 0.0,
  40. 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13586283, 0.23484103,
  41. 0.69677156, 0.0, 0.0, 0.08609836, 0.89882583
  42. ]
  43. ]
  44. ])
  45. labels = [7, 2]
  46. # Convert labels to one-hot encoding
  47. num_classes = 10
  48. labels_one_hot = np.zeros((len(labels), num_classes))
  49. for i, label in enumerate(labels):
  50. labels_one_hot[i, label] = 1
  51. # Calculate the loss derivative
  52. loss_derivative = outputs - labels_one_hot
  53. print("Loss Derivative:")
  54. print(loss_derivative)