PyTorch支持类型

要注意, torch.Tensortorch.FloatTensor 的别名。

类型转换方法

主要包括以下三种途径:

  1. 使用独立的类型函数;
  2. 使用 torch.type() 函数;
  3. 使用 type_as(tesnor) 将张量转换为给定类型的张量tensor

使用独立函数

  1. import torch
  2. tensor = torch.randn(3, 5)
  3. print(tensor)
  4. # torch.long() 将tensor投射为long类型
  5. long_tensor = tensor.long()
  6. print(long_tensor)
  7. # torch.half()将tensor投射为半精度浮点类型
  8. half_tensor = tensor.half()
  9. print(half_tensor)
  10. # torch.int()将该tensor投射为int类型
  11. int_tensor = tensor.int()
  12. print(int_tensor)
  13. # torch.double()将该tensor投射为double类型
  14. double_tensor = tensor.double()
  15. print(double_tensor)
  16. # torch.float()将该tensor投射为float类型
  17. float_tensor = tensor.float()
  18. print(float_tensor)
  19. # torch.char()将该tensor投射为char类型
  20. char_tensor = tensor.char()
  21. print(char_tensor)
  22. # torch.byte()将该tensor投射为byte类型
  23. byte_tensor = tensor.byte()
  24. print(byte_tensor)
  25. # torch.short()将该tensor投射为short类型
  26. short_tensor = tensor.short()
  27. print(short_tensor)
  28. # 输出 ################################################################################
  29. -0.5841 -1.6370 0.1353 0.6334 -3.0761
  30. -0.2628 0.1245 0.8626 0.4095 -0.3633
  31. 1.3605 0.5055 -2.0090 0.8933 -0.6267
  32. [torch.FloatTensor of size 3x5]
  33. 0 -1 0 0 -3
  34. 0 0 0 0 0
  35. 1 0 -2 0 0
  36. [torch.LongTensor of size 3x5]
  37. -0.5840 -1.6367 0.1353 0.6333 -3.0762
  38. -0.2627 0.1245 0.8628 0.4094 -0.3633
  39. 1.3604 0.5054 -2.0098 0.8936 -0.6265
  40. [torch.HalfTensor of size 3x5]
  41. 0 -1 0 0 -3
  42. 0 0 0 0 0
  43. 1 0 -2 0 0
  44. [torch.IntTensor of size 3x5]
  45. -0.5841 -1.6370 0.1353 0.6334 -3.0761
  46. -0.2628 0.1245 0.8626 0.4095 -0.3633
  47. 1.3605 0.5055 -2.0090 0.8933 -0.6267
  48. [torch.DoubleTensor of size 3x5]
  49. -0.5841 -1.6370 0.1353 0.6334 -3.0761
  50. -0.2628 0.1245 0.8626 0.4095 -0.3633
  51. 1.3605 0.5055 -2.0090 0.8933 -0.6267
  52. [torch.FloatTensor of size 3x5]
  53. 0 -1 0 0 -3
  54. 0 0 0 0 0
  55. 1 0 -2 0 0
  56. [torch.CharTensor of size 3x5]
  57. 0 255 0 0 253
  58. 0 0 0 0 0
  59. 1 0 254 0 0
  60. [torch.ByteTensor of size 3x5]
  61. 0 -1 0 0 -3
  62. 0 0 0 0 0
  63. 1 0 -2 0 0
  64. [torch.ShortTensor of size 3x5]

使用 torch.type()

type(new_type=None, async=False)

  1. import torch
  2. tensor = torch.randn(3, 5)
  3. print(tensor)
  4. int_tensor = tensor.type(torch.IntTensor)
  5. print(int_tensor)
  6. # 输出 ################################################################################
  7. -0.4449 0.0332 0.5187 0.1271 2.2303
  8. 1.3961 -0.1542 0.8498 -0.3438 -0.2834
  9. -0.5554 0.1684 1.5216 2.4527 0.0379
  10. [torch.FloatTensor of size 3x5]
  11. 0 0 0 0 2
  12. 1 0 0 0 0
  13. 0 0 1 2 0
  14. [torch.IntTensor of size 3x5]

使用 type_as(a)

  1. import torch
  2. tensor_1 = torch.FloatTensor(5)
  3. tensor_2 = torch.IntTensor([10, 20])
  4. tensor_1 = tensor_1.type_as(tensor_2)
  5. assert isinstance(tensor_1, torch.IntTensor)

参考链接