torch.sparse

译者:@王帅

校对者:@Timor

警告:

此 API 目前是实验性的 , 可能会在不久的将来发生变化 .

Torch 支持 COO(rdinate) 格式的稀疏张量 , 还能高效地存储和处理大多数元素为零的 张量 .

一个稀疏张量可以表示为一对稠密张量 : 一个张量的值和一个二维张量的指数 . 通过提供这两个张量以及稀疏张量的大小 (不能从这些张量推断!) , 可以构造一个稀疏张量 . 假设我们要在位置 (0,2) 处定义条目3 , 位置 (1,0) 的条目4 , 位置 (1,2) 的条目5的 稀疏张量 , 我们可以这样写 :

  1. >>> i = torch.LongTensor([[0, 1, 1],
  2. [2, 0, 2]])
  3. >>> v = torch.FloatTensor([3, 4, 5])
  4. >>> torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense()
  5. 0 0 3
  6. 4 0 5
  7. [torch.FloatTensor of size 2x3]

请注意 , LongTensor 的传入参数不是索引元组的列表 . 如果你想用这种方式编写索引 , 你应该在 将它们传递给稀疏构造函数之前进行转换 :

  1. >>> i = torch.LongTensor([[0, 2], [1, 0], [1, 2]])
  2. >>> v = torch.FloatTensor([3, 4, 5 ])
  3. >>> torch.sparse.FloatTensor(i.t(), v, torch.Size([2,3])).to_dense()
  4. 0 0 3
  5. 4 0 5
  6. [torch.FloatTensor of size 2x3]

你还可以构造混合稀疏张量 , 其中只有第一个n维是稀疏的 , 而其余维度是密集的 .

  1. >>> i = torch.LongTensor([[2, 4]])
  2. >>> v = torch.FloatTensor([[1, 3], [5, 7]])
  3. >>> torch.sparse.FloatTensor(i, v).to_dense()
  4. 0 0
  5. 0 0
  6. 1 3
  7. 0 0
  8. 5 7
  9. [torch.FloatTensor of size 5x2]

一个空的稀疏张量可以通过指定它的大小来构造 :

  1. >>> torch.sparse.FloatTensor(2, 3)
  2. SparseFloatTensor of size 2x3 with indices:
  3. [torch.LongTensor with no dimension]
  4. and values:
  5. [torch.FloatTensor with no dimension]

注解:

我们的稀疏张量格式允许非聚合稀疏张量 , 索引可能对应有重复的坐标 ; 在这 种情况下 , 该索引处的值代表所有重复条目值的总和 . 非聚合张量允许我们更 有效地实现确定的操作符 .

在大多数情况下 , 你不必关心稀疏张量是否聚合 , 因为大多数操作在聚合或 不聚合稀疏张量的情况下都会以相同的方式工作 . 但是 , 你可能需要关心两种情况 .

首先 , 如果你反复执行可以产生重复条目的操作 (例如 , torch.sparse.FloatTensor.add()) , 则应适当聚合稀疏张量以防止它们变得太大.

其次 , 一些操作符将根据是否聚合 (例如 , torch.sparse.FloatTensor._values()torch.sparse.FloatTensor._indices() , 还有 torch.Tensor._sparse_mask()) 来生成不同的值 . 这些运算符前面加下划线表示它们揭示 内部实现细节 , 因此应谨慎使 , 因为与聚合的稀疏张量一起工作的代码可能不适用于未聚合的稀疏张量 ; 一般来说 , 在运用这些运算符之前 , 最安全的就是确保是聚合的 .

例如 , 假设我们想直接通过 torch.sparse.FloatTensor._values() 来实现一个操作 . 随着乘法分布的增加 , 标量的乘法可以轻易实现 ; 然而 , 平方根不能直接实现 , sqrt(a + b) != sqrt(a) +sqrt(b) (如果给定一个非聚合张量 , 这将被计算出来 . )

  1. class torch.sparse.FloatTensor
  1. add()
  1. add_()
  1. clone()
  1. dim()
  1. div()
  1. div_()
  1. get_device()
  1. hspmm()
  1. mm()
  1. mul()
  1. mul_()
  1. resizeAs_()
  1. size()
  1. spadd()
  1. spmm()
  1. sspaddmm()
  1. sspmm()
  1. sub()
  1. sub_()
  1. t_()
  1. toDense()
  1. transpose()
  1. transpose_()
  1. zero_()
  1. coalesce()
  1. is_coalesced()
  1. _indices()
  1. _values()
  1. _nnz()