参考来源:
CSDN:maskedfill() - masked_fill() - v1.5.0

torch.Tensor
https://pytorch.org/docs/stable/tensors.html

  • **torch.Tensor.masked_fill** (Python method, in torch.Tensor)
  • **torch.Tensor.masked_fill_** (Python method, in torch.Tensor)

  • **masked_fill_(mask, value)** - 函数名后面加下划线。in-place version 在 PyTorch 中是指当改变一个 tensor 的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值,可以称为原地操作符。

  • **masked_fill(mask, value) -> Tensor** - 函数名后面没有下划线。out-of-place version 在 PyTorch 中是指当改变一个 tensor 的值的时候,经过复制操作,不是直接在原来的内存上改变它的值,而是修改复制的 tensor。

1. maskedfill(mask, value)

功能:
掩码操作。
用 value 填充 tensor 中与 mask 中值为 1 位置相对应的元素。mask 的形状必须与要填充的 tensor 形状一致。

参数:
mask (BoolTensor):mask是一个 pytorch 张量(Tensor),元素是布尔值。
value (float):value 是要填充的值。

注意:
参数 mask 必须与 t 的 size 相同或者两者是可广播(broadcasting-semantics)的。
关于广播(broadcasting-semantics),可参考 pytorch 广播语义(Broadcasting semantics)

2. masked_fill(mask, value) -> Tensor

Out-of-place version of torch.Tensor.maskedfill()

3. 示例

3.1 masked_fill(mask, value) -> Tensor

  1. import torch
  2. data = torch.randn(2, 3)
  3. print('data:\n', data)
  4. mask = torch.tensor([[True, False, True], [False, True, False]])
  5. print('mask:\n', mask)
  6. masked1 = data.masked_fill(mask, 999)
  7. print('data.masked_fill(mask, value)之后:')
  8. print('masked1:\n', masked1)
  9. print('data:\n', data)

结果:

  1. """output:
  2. data:
  3. tensor([[ 0.5203, -0.0460, 1.3925],
  4. [ 1.7319, -1.1121, -0.1828]])
  5. mask:
  6. tensor([[ True, False, True],
  7. [False, True, False]])
  8. data.masked_fill(mask, value)之后:
  9. masked1:
  10. tensor([[ 9.9900e+02, -4.6050e-02, 9.9900e+02],
  11. [ 1.7319e+00, 9.9900e+02, -1.8277e-01]])
  12. data:
  13. tensor([[ 0.5203, -0.0460, 1.3925],
  14. [ 1.7319, -1.1121, -0.1828]])
  15. """

3.2 maskedfill(mask, value)

  1. import torch
  2. data = torch.randn(2, 3)
  3. print('data:\n', data)
  4. mask = torch.tensor([[True, False, True], [False, True, False]])
  5. print('mask:\n', mask)
  6. masked1 = data.masked_fill_(mask, 999)
  7. print('data.masked_fill_(mask, value)之后:')
  8. print('masked1:\n', masked1)
  9. print('data:\n', data)

结果:

  1. """output:
  2. data:
  3. tensor([[ 0.5494, 0.3472, -0.0760],
  4. [ 0.5392, 0.0113, 0.9853]])
  5. mask:
  6. tensor([[ True, False, True],
  7. [False, True, False]])
  8. data.masked_fill_(mask, value)之后:
  9. masked1:
  10. tensor([[9.9900e+02, 3.4724e-01, 9.9900e+02],
  11. [5.3922e-01, 9.9900e+02, 9.8531e-01]])
  12. data:
  13. tensor([[9.9900e+02, 3.4724e-01, 9.9900e+02],
  14. [5.3922e-01, 9.9900e+02, 9.8531e-01]])
  15. """