示例：下面的交互式会话演示了对 `torch.svd` 的结果做反向传播时，输入张量必须先设置 `requires_grad=True` —— 否则调用 `backward()` 会抛出 "element 0 of tensors does not require grad" 错误；设置之后即可正常得到 `a.grad`。
>>> a = torch.randn(4, 8, 8)
>>> a.requires_grad
False
>>> a.grad
>>> u, s, v = torch.svd(a)
>>> u.size(), s.size(), v.size()
(torch.Size([4, 8, 8]), torch.Size([4, 8]), torch.Size([4, 8, 8]))
>>> d = torch.dist(a, u@s.diag_embed()@v.transpose(1, 2))
>>> d
tensor(6.1367e-06)
>>> d.backward()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/lart/.conda/envs/pt12/lib/python3.6/site-packages/torch/tensor.py", line 118, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/lart/.conda/envs/pt12/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>> a.requires_grad_(True)
tensor([[[ 5.0926e-01, -5.7983e-01, 7.3646e-01, 8.3811e-01, 1.1680e-01,
-1.3508e+00, -5.9514e-01, -4.2688e-01],
[ 9.9339e-01, 5.0686e-01, -9.3109e-01, 1.6483e+00, 1.8901e-01,
-4.1816e-01, -1.5053e+00, -8.3305e-01],
[-1.0060e+00, 5.2788e-01, 7.4413e-01, 5.1731e-01, -2.1892e+00,
-2.0613e+00, -9.1691e-01, -8.0884e-02],
[-6.8767e-01, -2.8735e+00, 7.3318e-01, 3.1855e-03, 3.2300e-01,
3.1714e-01, -3.6759e-01, 1.0884e+00],
[-1.1363e-01, 4.7537e-01, -5.2829e-01, 7.2855e-02, -1.9361e-01,
7.7375e-01, -5.7489e-01, -3.1300e-01],
[-7.4154e-01, -6.4932e-01, -1.7089e+00, -3.0536e+00, -5.8849e-01,
-8.0302e-01, 8.5439e-01, -3.2681e-01],
[-1.7654e+00, 1.6885e+00, -7.8303e-01, -1.1855e+00, -5.2198e-01,
-1.6007e+00, -4.4931e-01, 1.3803e+00],
[ 9.1247e-01, -9.5255e-01, 1.4173e+00, 4.8889e-01, 3.5064e-01,
3.5041e-02, 5.3359e-01, -1.0892e+00]],
[[-7.0987e-01, -5.0227e-01, 7.3428e-01, -1.4305e+00, -8.9247e-01,
1.0328e+00, 7.6211e-01, -9.2013e-01],
[ 1.8573e+00, -1.9633e-01, -2.2052e-01, 7.8514e-01, -7.0751e-01,
1.3906e+00, -1.8284e+00, 5.4344e-02],
[ 5.9049e-01, 2.8572e-01, 1.0436e+00, 1.4244e+00, 6.6587e-01,
-9.6773e-01, 5.7693e-02, 6.8744e-01],
[ 1.8168e+00, -8.1121e-01, -9.6194e-01, -4.5015e-01, 1.0638e+00,
2.2808e+00, -1.7837e-01, 5.9948e-01],
[-7.6185e-02, 5.2174e-01, 1.0115e+00, 1.6414e+00, 1.0200e-02,
-3.3982e-01, 2.4742e+00, 5.1641e-01],
[-7.8309e-01, -3.2473e-01, 4.3643e-01, -2.4305e+00, 2.2542e+00,
1.2450e-01, -1.1120e+00, -8.5926e-01],
[-3.5841e-01, -9.4043e-02, 5.5771e-02, 9.6267e-01, 6.9952e-03,
-3.7878e-01, -4.0624e-01, 1.2696e-01],
[ 2.1680e+00, -2.0111e-01, -5.6822e-02, 5.3845e-01, -4.1230e-02,
3.1459e+00, -1.0781e+00, 1.1412e+00]],
[[ 4.7876e-02, 1.2553e+00, 2.3790e+00, 9.7567e-01, 8.4884e-01,
2.5797e-02, -5.1782e-01, -1.6589e+00],
[ 1.9872e-01, 2.6252e-01, 1.5556e+00, -5.1998e-01, 4.8064e-02,
-2.9763e-01, 4.0865e-01, -7.6408e-01],
[-1.2351e+00, -9.3510e-01, -3.3864e-01, 1.5223e-01, -1.0920e+00,
-1.4765e+00, 1.6345e+00, -7.3910e-01],
[ 4.4334e-01, 9.9712e-01, -1.7089e+00, 6.7556e-01, 3.3336e-01,
-1.1793e+00, 2.1877e-01, -1.1102e+00],
[-8.3336e-01, 1.0436e-02, 3.1273e-01, -8.4080e-02, -2.3449e-01,
-1.1146e+00, 4.9459e-01, 3.7987e-01],
[-1.6838e+00, 7.7992e-01, -9.8158e-01, -1.7515e+00, 3.5891e-01,
-1.5335e-01, -1.7358e+00, 4.6637e-01],
[-1.5487e+00, -2.7994e-01, 4.7985e-01, 2.0904e-01, 5.8344e-01,
6.5077e-01, 1.2345e+00, 2.0636e+00],
[ 1.0334e+00, 1.0952e-01, 8.3518e-01, 4.5615e-01, 5.1996e-01,
-1.9327e-02, -1.5506e+00, 6.6045e-01]],
[[-1.5528e+00, -7.9126e-01, -2.2267e+00, -1.0876e+00, 1.1706e+00,
1.1638e+00, -8.7892e-01, -7.9665e-02],
[-1.6719e+00, 3.3261e-01, 1.3548e+00, 3.3109e-01, 1.2272e+00,
-2.0417e-01, 6.9749e-01, 8.0466e-01],
[ 1.4222e+00, -1.8551e+00, -5.3144e-01, 1.1752e+00, 5.2877e-01,
5.3952e-01, 6.6802e-01, -2.2321e-01],
[-2.2727e-01, -1.1718e+00, -1.0929e+00, -8.2438e-02, -2.8256e-01,
-1.4067e+00, -9.2816e-01, -1.1016e+00],
[ 8.6106e-01, -8.1874e-02, -5.6385e-02, -2.8466e+00, -1.8488e-01,
-3.0176e-01, -4.6510e-01, 3.4096e-01],
[ 5.0758e-02, 2.9008e-01, 1.1458e+00, -1.4306e-01, 4.3022e-01,
-2.1033e+00, 1.0509e+00, 2.7715e-01],
[ 4.5102e-01, -1.0666e+00, 1.4706e+00, 3.0153e-01, 1.1718e+00,
-4.4404e-01, -1.3700e-01, 2.1973e+00],
[ 8.6106e-01, 9.2344e-01, -1.4858e-02, 2.7544e-03, 4.6628e-01,
-9.7355e-01, -3.2367e-01, 1.3770e+00]]], requires_grad=True)
>>> a.requires_grad
True
>>> a.grad
>>> u, s, v = torch.svd(a)
>>> d = torch.dist(a, u@s.diag_embed()@v.transpose(1, 2))
>>> d
tensor(6.1367e-06, grad_fn=<DistBackward>)
>>> d.backward()
>>> a.grad
tensor([[[ 1.8626e-08, -3.7253e-08, 8.9407e-08, -5.4017e-08, 7.4506e-09,
-4.8429e-08, -5.2154e-08, 0.0000e+00],
[ 5.7742e-08, 7.4506e-08, -1.9372e-07, 1.0803e-07, 3.7253e-09,
3.7253e-09, -6.3330e-08, -1.1921e-07],
[-7.4506e-09, 2.6077e-08, 4.6566e-08, -4.8429e-08, -1.4901e-08,
-3.7253e-08, 1.4901e-08, 1.6764e-08],
[ 4.0978e-08, 2.2352e-08, -2.2352e-08, 4.4034e-08, 3.7253e-09,
7.4506e-09, -1.4901e-08, -2.2352e-08],
[-2.0489e-08, 2.0489e-08, -6.7055e-08, 7.6834e-09, -1.1176e-08,
3.3528e-08, 1.1176e-08, -2.7940e-09],
[-3.7253e-09, 8.3819e-08, -1.1921e-07, 3.7253e-08, -5.9605e-08,
-5.2154e-08, 7.4506e-09, -5.9605e-08],
[-2.9802e-08, 7.8231e-08, -1.1176e-08, -5.2154e-08, -3.3528e-08,
-1.4901e-08, 1.0245e-08, 5.5879e-09],
[ 2.2352e-08, -8.1956e-08, 1.5646e-07, -1.4901e-08, 1.3039e-08,
-7.4506e-09, -7.4506e-09, 1.4901e-08]],
[[-4.6566e-08, -3.7253e-08, 3.7253e-08, 1.4901e-08, -3.1665e-08,
-1.4901e-08, 1.1176e-08, 1.8626e-09],
[-1.8626e-08, 1.8626e-08, 7.4506e-09, 2.9802e-08, -2.4214e-08,
-2.9802e-08, 7.4506e-09, 7.4506e-09],
[ 1.1176e-08, 3.3528e-08, 7.4506e-09, 7.4506e-09, 2.9802e-08,
-7.4506e-09, 1.4901e-08, 0.0000e+00],
[-2.2352e-08, -1.4901e-08, -2.9802e-08, -5.9605e-08, 1.4901e-08,
4.4703e-08, -2.9802e-08, -1.4901e-08],
[ 1.8626e-08, 2.6077e-08, 3.7253e-08, -2.2352e-08, 7.4506e-09,
-1.4901e-08, 3.7253e-08, -7.4506e-09],
[-2.7940e-08, -8.3819e-09, 4.4703e-08, 2.9802e-08, 4.4703e-08,
1.8626e-08, -4.4703e-08, 4.4703e-08],
[ 5.5879e-09, -3.7253e-09, 0.0000e+00, 1.4901e-08, 1.4901e-08,
-9.3132e-09, 2.9802e-08, 1.1176e-08],
[ 0.0000e+00, 4.4703e-08, -2.2352e-08, -2.9802e-08, -1.8626e-09,
0.0000e+00, -1.4901e-08, -2.2352e-08]],
[[-1.3039e-08, 0.0000e+00, 3.7253e-08, -2.9802e-08, 0.0000e+00,
2.0489e-08, 7.4506e-09, -1.1176e-08],
[ 6.1467e-08, 2.9802e-08, -4.4703e-08, -5.5879e-09, -1.8626e-08,
-6.7055e-08, -2.9802e-08, -1.4901e-08],
[-2.9802e-08, -7.8231e-08, 3.7253e-09, 5.0291e-08, -4.0978e-08,
1.8626e-08, 8.9407e-08, 0.0000e+00],
[ 2.9802e-08, 2.9802e-08, -2.0489e-08, -1.4901e-08, -7.4506e-09,
-4.8429e-08, 4.1910e-09, -3.7253e-09],
[-1.1176e-08, 1.8626e-09, -1.2107e-08, 3.9116e-08, 1.8626e-09,
-2.9802e-08, 1.4901e-08, 2.0489e-08],
[-6.7055e-08, 7.4506e-08, 1.8626e-09, -9.6858e-08, 1.3039e-08,
-2.0489e-08, -4.4703e-08, -2.2352e-08],
[-6.3330e-08, -7.4506e-09, 5.2154e-08, 1.8626e-08, 0.0000e+00,
-2.2352e-08, 2.6077e-08, 2.2352e-08],
[ 3.3528e-08, 6.7521e-09, 2.2352e-08, 4.4703e-08, 2.6077e-08,
-5.2154e-08, -1.4901e-08, 1.4901e-08]],
[[-8.5682e-08, 1.1548e-07, 2.9802e-08, 1.8626e-08, -3.3528e-08,
1.8626e-08, 1.1176e-08, -1.3411e-07],
[-2.0862e-07, 5.9605e-08, 4.4703e-08, 1.2666e-07, 1.8626e-08,
8.3819e-08, 4.4703e-08, -7.0781e-08],
[ 3.1292e-07, -4.8429e-08, -2.6077e-08, -3.7253e-09, -3.3528e-08,
-1.1548e-07, 5.5879e-09, -6.7055e-08],
[ 1.3784e-07, -7.4506e-08, -8.3819e-09, 7.4506e-09, 4.0978e-08,
-7.4506e-08, 0.0000e+00, -8.9407e-08],
[ 2.2352e-07, 2.2119e-08, -9.3132e-10, 7.4506e-09, -8.5682e-08,
-7.4506e-08, -2.9802e-08, 4.0978e-08],
[-7.4506e-09, -2.4214e-08, 1.1176e-08, -9.1270e-08, 3.8184e-08,
-4.4703e-08, 3.3528e-08, 2.9802e-08],
[ 2.6077e-08, 3.7253e-08, 4.4703e-08, 3.7253e-08, -7.8231e-08,
1.4901e-08, 1.3039e-08, -2.9802e-08],
[ 8.1956e-08, 3.9116e-08, 7.4506e-09, -9.3132e-09, -7.4506e-09,
-1.8626e-09, -1.4901e-08, 7.4506e-09]]])
相关链接