参考来源:
CSDN:maskedfill() - masked_fill() - v1.5.0
torch.Tensor
https://pytorch.org/docs/stable/tensors.html
**torch.Tensor.masked_fill**
(Python method, in torch.Tensor)**torch.Tensor.masked_fill_**
(Python method, in torch.Tensor)**masked_fill_(mask, value)**
- 函数名后面加下划线。in-place version 在 PyTorch 中是指当改变一个 tensor 的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值,可以称为原地操作符。**masked_fill(mask, value) -> Tensor**
- 函数名后面没有下划线。out-of-place version 在 PyTorch 中是指当改变一个 tensor 的值的时候,经过复制操作,不是直接在原来的内存上改变它的值,而是修改复制的 tensor。
1. maskedfill(mask, value)
功能:
掩码操作。
用 value 填充 tensor 中与 mask 中值为 1 位置相对应的元素。mask 的形状必须与要填充的 tensor 形状一致。
参数:
mask (BoolTensor):mask是一个 pytorch 张量(Tensor),元素是布尔值。
value (float):value 是要填充的值。
注意:
参数 mask 必须与 t 的 size 相同或者两者是可广播(broadcasting-semantics)的。
关于广播(broadcasting-semantics),可参考 pytorch 广播语义(Broadcasting semantics)。
2. masked_fill(mask, value) -> Tensor
Out-of-place version of torch.Tensor.maskedfill()
3. 示例
3.1 masked_fill(mask, value) -> Tensor
import torch
data = torch.randn(2, 3)
print('data:\n', data)
mask = torch.tensor([[True, False, True], [False, True, False]])
print('mask:\n', mask)
masked1 = data.masked_fill(mask, 999)
print('data.masked_fill(mask, value)之后:')
print('masked1:\n', masked1)
print('data:\n', data)
结果:
"""output:
data:
tensor([[ 0.5203, -0.0460, 1.3925],
[ 1.7319, -1.1121, -0.1828]])
mask:
tensor([[ True, False, True],
[False, True, False]])
data.masked_fill(mask, value)之后:
masked1:
tensor([[ 9.9900e+02, -4.6050e-02, 9.9900e+02],
[ 1.7319e+00, 9.9900e+02, -1.8277e-01]])
data:
tensor([[ 0.5203, -0.0460, 1.3925],
[ 1.7319, -1.1121, -0.1828]])
"""
3.2 maskedfill(mask, value)
import torch
data = torch.randn(2, 3)
print('data:\n', data)
mask = torch.tensor([[True, False, True], [False, True, False]])
print('mask:\n', mask)
masked1 = data.masked_fill_(mask, 999)
print('data.masked_fill_(mask, value)之后:')
print('masked1:\n', masked1)
print('data:\n', data)
结果:
"""output:
data:
tensor([[ 0.5494, 0.3472, -0.0760],
[ 0.5392, 0.0113, 0.9853]])
mask:
tensor([[ True, False, True],
[False, True, False]])
data.masked_fill_(mask, value)之后:
masked1:
tensor([[9.9900e+02, 3.4724e-01, 9.9900e+02],
[5.3922e-01, 9.9900e+02, 9.8531e-01]])
data:
tensor([[9.9900e+02, 3.4724e-01, 9.9900e+02],
[5.3922e-01, 9.9900e+02, 9.8531e-01]])
"""