PyTorch提取中间层特征? - 知乎 https://www.zhihu.com/question/68384370
借助PyTorch Hook机制
PyTorch提取中间层特征? - 袁坤的回答 - 知乎 https://www.zhihu.com/question/68384370/answer/419741762
建议使用hook,在不改变网络forward函数的基础上提取所需的特征或者梯度,在调用阶段对module使用即可获得所需梯度或者特征。
inter_feature = {}
inter_gradient = {}
def make_hook(name, flag):
if flag == 'forward':
def hook(m, input, output):
inter_feature[name] = input
return hook
elif flag == 'backward':
def hook(m, input, output):
inter_gradient[name] = output
return hook
else:
assert False
m.register_forward_hook(make_hook(name, 'forward'))
m.register_backward_hook(make_hook(name, 'backward'))
在前向计算和反向计算的时候即可达到类似钩子的作用,中间变量已经被放置于 inter_feature
和 inter_gradient
。
output = model(input) # achieve intermediate feature
loss = criterion(output, target)
loss.backward() # achieve backward intermediate gradients
最后可根据需求是否释放hook。
m.remove()
PyTorch提取中间层特征? - 涩醉的回答 - 知乎 https://www.zhihu.com/question/68384370/answer/751212803
通过pytorch的hook机制简单实现了一下,只输出conv层的特征图。
详细可以看下面的blog:涩醉:pytorch使用hook打印中间特征图、计算网络算力等。
懒得跳转,可以直接看下面这份代码。
import torch
from torchvision.models import resnet18
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
def viz(module, input):
x = input[0][0]
#最多显示4张图
min_num = np.minimum(4, x.size()[0])
for i in range(min_num):
plt.subplot(1, 4, i+1)
plt.imshow(x[i])
plt.show()
import cv2
import numpy as np
def main():
t = transforms.Compose([transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = resnet18(pretrained=True).to(device)
for name, m in model.named_modules():
# if not isinstance(m, torch.nn.ModuleList) and \
# not isinstance(m, torch.nn.Sequential) and \
# type(m) in torch.nn.__dict__.values():
# 这里只对卷积层的feature map进行显示
if isinstance(m, torch.nn.Conv2d):
m.register_forward_pre_hook(viz)
img = cv2.imread('/Users/edgar/Desktop/cat.jpeg')
img = t(img).unsqueeze(0).to(device)
with torch.no_grad():
model(img)
if __name__ == '__main__':
main()
借助模型类的属性传递
PyTorch提取中间层特征? - 登高居士的回答 - 知乎 https://www.zhihu.com/question/68384370/answer/812588336
如何得到中间层特征:如果只想得到中间层特征,而不需要得到 gradient
之类的,那么不需要hook函数这么复杂。只需要在forward函数中添加一行代码,将feature赋值给self变量即可,即 self.feature_map = feature
给一个例子:
# Define a Convolutional Neural Network
class Net(nn.Module):
def __init__(self, kernel_size=5, n_filters=16, n_layers=3):
xxx
def forward(self, x):
x = self.body(self.head(x))
self.featuremap1 = x # 核心代码
return F.relu(self.fc(x))
model_ft = Net()
train_model(model_ft)
feature_output1 = model_ft.featuremap1.transpose(1,0).cpu().detach()
这样就得到了 feature_map
,并保存到了feature_output变量中。如何显示中间层特征:给出一个简单显示代码
def feature_imshow(inp, title=None):
"""Imshow for Tensor."""
inp = inp.detach().numpy().transpose((1, 2, 0))
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated
out = torchvision.utils.make_grid(feature_ouput1)
feature_imshow(out)
结果图如下: