1、tensor.item() 用于取出只含一个元素的 tensor 的值，并将其作为 Python 标量（如 float）返回

    # coding:utf-8
    """Demo of basic torch.Tensor usage: fit a cubic polynomial to sin(x).

    The coefficients a, b, c, d of ``y = a + b*x + c*x^2 + d*x^3`` are 0-dim
    tensors updated with hand-written gradient descent (autograd is not used).
    ``Tensor.item()`` is called wherever a plain Python number is wanted from
    a single-element tensor: the scalar loss and the final coefficients.
    """
    import math

    import torch

    dtype = torch.float
    device = torch.device("cpu")

    # Training data: 2000 evenly spaced points over [-pi, pi] and their sine.
    x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
    y = torch.sin(x)

    # The powers of x never change between iterations, so compute them once
    # instead of twice per step inside the loop.
    x_squared = x ** 2
    x_cubed = x ** 3

    # Randomly initialised coefficients; the empty shape () makes each one a
    # 0-dim (scalar) tensor.
    a = torch.randn((), device=device, dtype=dtype)
    b = torch.randn((), device=device, dtype=dtype)
    c = torch.randn((), device=device, dtype=dtype)
    d = torch.randn((), device=device, dtype=dtype)

    learning_rate = 1e-6
    for t in range(2000):
        # Forward pass: evaluate the polynomial at every sample point.
        y_pred = a + b * x + c * x_squared + d * x_cubed

        # Sum of squared errors. .sum() yields a 0-dim tensor and .item()
        # turns it into a Python float for printing.
        loss = (y_pred - y).pow(2).sum().item()
        if t % 100 == 99:
            print(t, loss)

        # Backward pass by hand: d(loss)/d(y_pred) = 2 * (y_pred - y), then the
        # chain rule gives the gradient with respect to each coefficient.
        grad_y_pred = 2.0 * (y_pred - y)
        grad_a = grad_y_pred.sum()
        grad_b = (grad_y_pred * x).sum()
        grad_c = (grad_y_pred * x_squared).sum()
        grad_d = (grad_y_pred * x_cubed).sum()

        # Gradient-descent step, updating the coefficient tensors in place.
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
        c -= learning_rate * grad_c
        d -= learning_rate * grad_d

    print(f"Result : y ={a.item()}+{b.item()}x+{c.item()}x^2+{d.item()}x^3")