示例

  1. >>> a = torch.randn(4, 8, 8)
  2. >>> a.requires_grad
  3. False
  4. >>> a.grad
  5. >>> u, s, v = torch.svd(a)
  6. >>> u.size(), s.size(), v.size()
  7. (torch.Size([4, 8, 8]), torch.Size([4, 8]), torch.Size([4, 8, 8]))
  8. >>> d = torch.dist(a, u@s.diag_embed()@v.transpose(1, 2))
  9. >>> d
  10. tensor(6.1367e-06)
  11. >>> d.backward()
  12. Traceback (most recent call last):
  13.   File "<stdin>", line 1, in <module>
  14.   File "/home/lart/.conda/envs/pt12/lib/python3.6/site-packages/torch/tensor.py", line 118, in backward
  15.     torch.autograd.backward(self, gradient, retain_graph, create_graph)
  16.   File "/home/lart/.conda/envs/pt12/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
  17.     allow_unreachable=True) # allow_unreachable flag
  18. RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
  1. >>> a.requires_grad_(True)
  2. tensor([[[ 5.0926e-01, -5.7983e-01, 7.3646e-01, 8.3811e-01, 1.1680e-01,
  3. -1.3508e+00, -5.9514e-01, -4.2688e-01],
  4. [ 9.9339e-01, 5.0686e-01, -9.3109e-01, 1.6483e+00, 1.8901e-01,
  5. -4.1816e-01, -1.5053e+00, -8.3305e-01],
  6. [-1.0060e+00, 5.2788e-01, 7.4413e-01, 5.1731e-01, -2.1892e+00,
  7. -2.0613e+00, -9.1691e-01, -8.0884e-02],
  8. [-6.8767e-01, -2.8735e+00, 7.3318e-01, 3.1855e-03, 3.2300e-01,
  9. 3.1714e-01, -3.6759e-01, 1.0884e+00],
  10. [-1.1363e-01, 4.7537e-01, -5.2829e-01, 7.2855e-02, -1.9361e-01,
  11. 7.7375e-01, -5.7489e-01, -3.1300e-01],
  12. [-7.4154e-01, -6.4932e-01, -1.7089e+00, -3.0536e+00, -5.8849e-01,
  13. -8.0302e-01, 8.5439e-01, -3.2681e-01],
  14. [-1.7654e+00, 1.6885e+00, -7.8303e-01, -1.1855e+00, -5.2198e-01,
  15. -1.6007e+00, -4.4931e-01, 1.3803e+00],
  16. [ 9.1247e-01, -9.5255e-01, 1.4173e+00, 4.8889e-01, 3.5064e-01,
  17. 3.5041e-02, 5.3359e-01, -1.0892e+00]],
  18. [[-7.0987e-01, -5.0227e-01, 7.3428e-01, -1.4305e+00, -8.9247e-01,
  19. 1.0328e+00, 7.6211e-01, -9.2013e-01],
  20. [ 1.8573e+00, -1.9633e-01, -2.2052e-01, 7.8514e-01, -7.0751e-01,
  21. 1.3906e+00, -1.8284e+00, 5.4344e-02],
  22. [ 5.9049e-01, 2.8572e-01, 1.0436e+00, 1.4244e+00, 6.6587e-01,
  23. -9.6773e-01, 5.7693e-02, 6.8744e-01],
  24. [ 1.8168e+00, -8.1121e-01, -9.6194e-01, -4.5015e-01, 1.0638e+00,
  25. 2.2808e+00, -1.7837e-01, 5.9948e-01],
  26. [-7.6185e-02, 5.2174e-01, 1.0115e+00, 1.6414e+00, 1.0200e-02,
  27. -3.3982e-01, 2.4742e+00, 5.1641e-01],
  28. [-7.8309e-01, -3.2473e-01, 4.3643e-01, -2.4305e+00, 2.2542e+00,
  29. 1.2450e-01, -1.1120e+00, -8.5926e-01],
  30. [-3.5841e-01, -9.4043e-02, 5.5771e-02, 9.6267e-01, 6.9952e-03,
  31. -3.7878e-01, -4.0624e-01, 1.2696e-01],
  32. [ 2.1680e+00, -2.0111e-01, -5.6822e-02, 5.3845e-01, -4.1230e-02,
  33. 3.1459e+00, -1.0781e+00, 1.1412e+00]],
  34. [[ 4.7876e-02, 1.2553e+00, 2.3790e+00, 9.7567e-01, 8.4884e-01,
  35. 2.5797e-02, -5.1782e-01, -1.6589e+00],
  36. [ 1.9872e-01, 2.6252e-01, 1.5556e+00, -5.1998e-01, 4.8064e-02,
  37. -2.9763e-01, 4.0865e-01, -7.6408e-01],
  38. [-1.2351e+00, -9.3510e-01, -3.3864e-01, 1.5223e-01, -1.0920e+00,
  39. -1.4765e+00, 1.6345e+00, -7.3910e-01],
  40. [ 4.4334e-01, 9.9712e-01, -1.7089e+00, 6.7556e-01, 3.3336e-01,
  41. -1.1793e+00, 2.1877e-01, -1.1102e+00],
  42. [-8.3336e-01, 1.0436e-02, 3.1273e-01, -8.4080e-02, -2.3449e-01,
  43. -1.1146e+00, 4.9459e-01, 3.7987e-01],
  44. [-1.6838e+00, 7.7992e-01, -9.8158e-01, -1.7515e+00, 3.5891e-01,
  45. -1.5335e-01, -1.7358e+00, 4.6637e-01],
  46. [-1.5487e+00, -2.7994e-01, 4.7985e-01, 2.0904e-01, 5.8344e-01,
  47. 6.5077e-01, 1.2345e+00, 2.0636e+00],
  48. [ 1.0334e+00, 1.0952e-01, 8.3518e-01, 4.5615e-01, 5.1996e-01,
  49. -1.9327e-02, -1.5506e+00, 6.6045e-01]],
  50. [[-1.5528e+00, -7.9126e-01, -2.2267e+00, -1.0876e+00, 1.1706e+00,
  51. 1.1638e+00, -8.7892e-01, -7.9665e-02],
  52. [-1.6719e+00, 3.3261e-01, 1.3548e+00, 3.3109e-01, 1.2272e+00,
  53. -2.0417e-01, 6.9749e-01, 8.0466e-01],
  54. [ 1.4222e+00, -1.8551e+00, -5.3144e-01, 1.1752e+00, 5.2877e-01,
  55. 5.3952e-01, 6.6802e-01, -2.2321e-01],
  56. [-2.2727e-01, -1.1718e+00, -1.0929e+00, -8.2438e-02, -2.8256e-01,
  57. -1.4067e+00, -9.2816e-01, -1.1016e+00],
  58. [ 8.6106e-01, -8.1874e-02, -5.6385e-02, -2.8466e+00, -1.8488e-01,
  59. -3.0176e-01, -4.6510e-01, 3.4096e-01],
  60. [ 5.0758e-02, 2.9008e-01, 1.1458e+00, -1.4306e-01, 4.3022e-01,
  61. -2.1033e+00, 1.0509e+00, 2.7715e-01],
  62. [ 4.5102e-01, -1.0666e+00, 1.4706e+00, 3.0153e-01, 1.1718e+00,
  63. -4.4404e-01, -1.3700e-01, 2.1973e+00],
  64. [ 8.6106e-01, 9.2344e-01, -1.4858e-02, 2.7544e-03, 4.6628e-01,
  65. -9.7355e-01, -3.2367e-01, 1.3770e+00]]], requires_grad=True)
  66. >>> a.requires_grad
  67. True
  68. >>> a.grad
  69. >>> u, s, v = torch.svd(a)
  70. >>> d = torch.dist(a, u@s.diag_embed()@v.transpose(1, 2))
  71. >>> d
  72. tensor(6.1367e-06, grad_fn=<DistBackward>)
  73. >>> d.backward()
  74. >>> a.grad
  75. tensor([[[ 1.8626e-08, -3.7253e-08, 8.9407e-08, -5.4017e-08, 7.4506e-09,
  76. -4.8429e-08, -5.2154e-08, 0.0000e+00],
  77. [ 5.7742e-08, 7.4506e-08, -1.9372e-07, 1.0803e-07, 3.7253e-09,
  78. 3.7253e-09, -6.3330e-08, -1.1921e-07],
  79. [-7.4506e-09, 2.6077e-08, 4.6566e-08, -4.8429e-08, -1.4901e-08,
  80. -3.7253e-08, 1.4901e-08, 1.6764e-08],
  81. [ 4.0978e-08, 2.2352e-08, -2.2352e-08, 4.4034e-08, 3.7253e-09,
  82. 7.4506e-09, -1.4901e-08, -2.2352e-08],
  83. [-2.0489e-08, 2.0489e-08, -6.7055e-08, 7.6834e-09, -1.1176e-08,
  84. 3.3528e-08, 1.1176e-08, -2.7940e-09],
  85. [-3.7253e-09, 8.3819e-08, -1.1921e-07, 3.7253e-08, -5.9605e-08,
  86. -5.2154e-08, 7.4506e-09, -5.9605e-08],
  87. [-2.9802e-08, 7.8231e-08, -1.1176e-08, -5.2154e-08, -3.3528e-08,
  88. -1.4901e-08, 1.0245e-08, 5.5879e-09],
  89. [ 2.2352e-08, -8.1956e-08, 1.5646e-07, -1.4901e-08, 1.3039e-08,
  90. -7.4506e-09, -7.4506e-09, 1.4901e-08]],
  91. [[-4.6566e-08, -3.7253e-08, 3.7253e-08, 1.4901e-08, -3.1665e-08,
  92. -1.4901e-08, 1.1176e-08, 1.8626e-09],
  93. [-1.8626e-08, 1.8626e-08, 7.4506e-09, 2.9802e-08, -2.4214e-08,
  94. -2.9802e-08, 7.4506e-09, 7.4506e-09],
  95. [ 1.1176e-08, 3.3528e-08, 7.4506e-09, 7.4506e-09, 2.9802e-08,
  96. -7.4506e-09, 1.4901e-08, 0.0000e+00],
  97. [-2.2352e-08, -1.4901e-08, -2.9802e-08, -5.9605e-08, 1.4901e-08,
  98. 4.4703e-08, -2.9802e-08, -1.4901e-08],
  99. [ 1.8626e-08, 2.6077e-08, 3.7253e-08, -2.2352e-08, 7.4506e-09,
  100. -1.4901e-08, 3.7253e-08, -7.4506e-09],
  101. [-2.7940e-08, -8.3819e-09, 4.4703e-08, 2.9802e-08, 4.4703e-08,
  102. 1.8626e-08, -4.4703e-08, 4.4703e-08],
  103. [ 5.5879e-09, -3.7253e-09, 0.0000e+00, 1.4901e-08, 1.4901e-08,
  104. -9.3132e-09, 2.9802e-08, 1.1176e-08],
  105. [ 0.0000e+00, 4.4703e-08, -2.2352e-08, -2.9802e-08, -1.8626e-09,
  106. 0.0000e+00, -1.4901e-08, -2.2352e-08]],
  107. [[-1.3039e-08, 0.0000e+00, 3.7253e-08, -2.9802e-08, 0.0000e+00,
  108. 2.0489e-08, 7.4506e-09, -1.1176e-08],
  109. [ 6.1467e-08, 2.9802e-08, -4.4703e-08, -5.5879e-09, -1.8626e-08,
  110. -6.7055e-08, -2.9802e-08, -1.4901e-08],
  111. [-2.9802e-08, -7.8231e-08, 3.7253e-09, 5.0291e-08, -4.0978e-08,
  112. 1.8626e-08, 8.9407e-08, 0.0000e+00],
  113. [ 2.9802e-08, 2.9802e-08, -2.0489e-08, -1.4901e-08, -7.4506e-09,
  114. -4.8429e-08, 4.1910e-09, -3.7253e-09],
  115. [-1.1176e-08, 1.8626e-09, -1.2107e-08, 3.9116e-08, 1.8626e-09,
  116. -2.9802e-08, 1.4901e-08, 2.0489e-08],
  117. [-6.7055e-08, 7.4506e-08, 1.8626e-09, -9.6858e-08, 1.3039e-08,
  118. -2.0489e-08, -4.4703e-08, -2.2352e-08],
  119. [-6.3330e-08, -7.4506e-09, 5.2154e-08, 1.8626e-08, 0.0000e+00,
  120. -2.2352e-08, 2.6077e-08, 2.2352e-08],
  121. [ 3.3528e-08, 6.7521e-09, 2.2352e-08, 4.4703e-08, 2.6077e-08,
  122. -5.2154e-08, -1.4901e-08, 1.4901e-08]],
  123. [[-8.5682e-08, 1.1548e-07, 2.9802e-08, 1.8626e-08, -3.3528e-08,
  124. 1.8626e-08, 1.1176e-08, -1.3411e-07],
  125. [-2.0862e-07, 5.9605e-08, 4.4703e-08, 1.2666e-07, 1.8626e-08,
  126. 8.3819e-08, 4.4703e-08, -7.0781e-08],
  127. [ 3.1292e-07, -4.8429e-08, -2.6077e-08, -3.7253e-09, -3.3528e-08,
  128. -1.1548e-07, 5.5879e-09, -6.7055e-08],
  129. [ 1.3784e-07, -7.4506e-08, -8.3819e-09, 7.4506e-09, 4.0978e-08,
  130. -7.4506e-08, 0.0000e+00, -8.9407e-08],
  131. [ 2.2352e-07, 2.2119e-08, -9.3132e-10, 7.4506e-09, -8.5682e-08,
  132. -7.4506e-08, -2.9802e-08, 4.0978e-08],
  133. [-7.4506e-09, -2.4214e-08, 1.1176e-08, -9.1270e-08, 3.8184e-08,
  134. -4.4703e-08, 3.3528e-08, 2.9802e-08],
  135. [ 2.6077e-08, 3.7253e-08, 4.4703e-08, 3.7253e-08, -7.8231e-08,
  136. 1.4901e-08, 1.3039e-08, -2.9802e-08],
  137. [ 8.1956e-08, 3.9116e-08, 7.4506e-09, -9.3132e-09, -7.4506e-09,
  138. -1.8626e-09, -1.4901e-08, 7.4506e-09]]])

相关链接