import torch
import torch.nn as nn
import os
from torchvision.models.resnet import *
import numpy as np
import matplotlib.pyplot as plt
import time
import cv2
from visualize import visualize_grid_attention_v2
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
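# SpatialAttention turns an (N, C, H, W) feature map into an (N, 1, H, W) mask with
# values in (0, 1); multiplying the mask back onto the features broadcasts it across
# all C channels (CBAM-style spatial attention).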
def draw_features(width, height, x, savename):
    tic = time.time()
    fig = plt.figure(figsize=(16, 16))
    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05)
    for i in range(width * height):
        plt.subplot(height, width, i + 1)
        plt.axis('off')
        img = x[0, i, :, :]
        pmin = np.min(img)
        pmax = np.max(img)
        img = ((img - pmin) / (pmax - pmin + 0.000001)) * 255  # normalise the float map to [0, 1], then scale to 0-255
        img = img.astype(np.uint8)  # convert to uint8
        img = cv2.applyColorMap(img, cv2.COLORMAP_JET)  # apply a JET heat map
        img = img[:, :, ::-1]  # cv2 returns BGR; reverse the channel order to RGB for matplotlib
        plt.imshow(img)
        print("{}/{}".format(i, width * height))
    fig.savefig(savename, dpi=100)
    fig.clf()
    plt.close()
    print("time:{}".format(time.time() - tic))
make_pic = False
n_classes = 2
class Resnet50_att(nn.Module):
    def __init__(self, wsl_path, savepath=None):
        super(Resnet50_att, self).__init__()
        print("model: resnet50")
        checkpoint = os.path.join(wsl_path, 'resnet50_wsl.pth')
        model = resnet50()
        state_dict = torch.load(checkpoint)
        model.load_state_dict(state_dict)
        model.fc = torch.nn.Linear(2048, n_classes)
        self.model = model
        self.savepath = savepath
        # attention modules; channel attention is kept here for reference but not used
        # self.ca = ChannelAttention(self.inplanes)
        self.sa_1 = SpatialAttention()
        # a second spatial attention module intended for a deeper layer (unused in forward)
        self.sa_2 = SpatialAttention()
    def forward(self, x):
        origin_img = x
        att1 = 0
        att2 = 0
        # only the else branch below is used in training; the savepath branch just draws feature maps
        if self.savepath:  # draw features or not; call under torch.no_grad() so .numpy() works
            x = self.model.conv1(x)
            draw_features(8, 8, x.cpu().numpy(), "{}/f1_conv1.png".format(self.savepath))
            x = self.model.bn1(x)
            draw_features(8, 8, x.cpu().numpy(), "{}/f2_bn1.png".format(self.savepath))
            x = self.model.relu(x)
            draw_features(8, 8, x.cpu().numpy(), "{}/f3_relu.png".format(self.savepath))
            x = self.model.maxpool(x)
            draw_features(8, 8, x.cpu().numpy(), "{}/f4_maxpool.png".format(self.savepath))
            x = self.model.layer1(x)
            draw_features(16, 16, x.cpu().numpy(), "{}/f5_layer1.png".format(self.savepath))
            x = self.model.layer2(x)
            draw_features(16, 32, x.cpu().numpy(), "{}/f6_layer2.png".format(self.savepath))
            x = self.model.layer3(x)
            draw_features(32, 32, x.cpu().numpy(), "{}/f7_layer3.png".format(self.savepath))
            x = self.model.layer4(x)
            draw_features(32, 32, x.cpu().numpy()[:, 0:1024, :, :], "{}/f8_layer4_1.png".format(self.savepath))
            draw_features(32, 32, x.cpu().numpy()[:, 1024:2048, :, :], "{}/f8_layer4_2.png".format(self.savepath))
            x = self.model.avgpool(x)
            plt.plot(np.linspace(1, 2048, 2048), x.cpu().numpy()[0, :, 0, 0])
            plt.savefig("{}/f9_avgpool.png".format(self.savepath))
            plt.clf()
            plt.close()
            x = x.view(x.size(0), -1)
            x = self.model.fc(x)
            plt.plot(np.linspace(1, n_classes, n_classes), x.cpu().numpy()[0, :])
            plt.savefig("{}/f10_fc.png".format(self.savepath))
            plt.clf()
            plt.close()
        else:
            x = self.model.conv1(x)
            x = self.model.bn1(x)
            x = self.model.relu(x)
            att1 = self.sa_1(x)
            x = att1 * x
            # x = self.sa_1(x) * x
            x = self.model.maxpool(x)
            x = self.model.layer1(x)
            # att2 = self.sa_2(x)
            # x = att2 * x
            x = self.model.layer2(x)
            x = self.model.layer3(x)
            x = self.model.layer4(x)
            # x = self.sa_2(x) * x
            x = self.model.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.model.fc(x)
        # return x
        return x, att1
# directory that holds the pretrained weight file resnet50_wsl.pth
wsl_path = ""
model = Resnet50_att(wsl_path=wsl_path)
inputs = ""  # placeholder: a batch of images with shape (N, 3, H, W)
outputs, att = model(inputs)  # the model returns (logits, attention map)
att1 = att.detach().cpu().numpy()
save_path_1 = "/media/zhujunjie/dataset/attVision/att1"
# img_ids and dir_train_img come from the dataset that produced `inputs`
for index, image_id in enumerate(img_ids):
    img_path = os.path.join(dir_train_img, image_id + '.jpg')
    visualize_grid_attention_v2(img_path,
                                save_path=save_path_1,
                                attention_mask=att1[index][0],
                                save_image=True,
                                save_original_image=True,
                                quality=100)
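For reference, with a 224x224 input the attention map returned above is taken right after conv1/bn1/relu (stride 2), so `att` has shape (N, 1, 112, 112) and `att1[index][0]` is the 2-D mask for one image. A quick shape check on random data, assuming the weights above loaded successfully (illustrative only, not part of the original script):

# Shape check with random data (illustrative only).
dummy = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    logits, att_map = model(dummy)
print(logits.shape)         # torch.Size([4, 2]): one logit per class
print(att_map.shape)        # torch.Size([4, 1, 112, 112]): one spatial mask per image
print(att_map[0, 0].shape)  # torch.Size([112, 112]): the 2-D mask passed to the visualizer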

Data cleaning

# accumulators for the validation sweep; valloader, model, device, logger,
# WORK_DIR, epoch and fold come from the surrounding training script
fake_predict = pd.DataFrame(columns=['image_id', 'p1-p0'])
prob_all, lable_all = [], []
for step, batch in enumerate(valloader):
    if step % 1000 == 0:
        logger.info('Valid step {} of {}'.format(step, len(valloader)))
    inputs = batch["image"]
    inputs = inputs.to(device, dtype=torch.float)
    outputs = model(inputs, meta_inputs)  # meta_inputs: the extra metadata input this model variant expects, prepared earlier in the script
    sfx = nn.Softmax(dim=-1)(outputs)
    sfx_cpu = sfx.detach().cpu().numpy()
    # sfx = sfx.detach().cpu().numpy()
    # column 0 holds the predicted probability of the negative class, column 1 of the positive class
    predict_per = sfx.argmax(dim=1)
    prob = sfx_cpu[:, 1]  # probability of the positive class (the original took argmax indices here, which are not probabilities)
    truel = batch['target'].detach().cpu().numpy()
    # keep the misclassified samples: p1-p0 > 0 means a negative sample was classified as positive,
    # p1-p0 < 0 means a positive sample was classified as negative
    for i in range(len(predict_per)):
        if predict_per[i] != truel[i]:
            a = {'image_id': batch['image_id'][i], 'p1-p0': sfx_cpu[i][1] - sfx_cpu[i][0]}
            # DataFrame.append works in pandas < 2.0; with newer pandas, collect rows in a list and build the frame once
            fake_predict = fake_predict.append(a, ignore_index=True)
    prob_all.extend(prob)
    lable_all.extend(truel)
probability_outfile = os.path.join(WORK_DIR,
                                   'probability/errprob_attUResnet50_ep{}_fold{}.csv'.format(epoch, fold))
fake_predict.to_csv(probability_outfile, index=False)

Open the exported probability file in Excel, sort it by the p1-p0 column, select a portion of the rows with the largest and smallest values, copy them into a new sheet, and save it as a CSV file: delImg.csv (a pandas sketch of the same step follows).
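The same selection can be done in pandas instead of Excel. A minimal sketch, assuming the error file produced above; the cut of 100 rows from each end is only an illustrative choice:

import pandas as pd

# Pick the rows with the largest and smallest p1-p0 values.
errprob = pd.read_csv('errprob_attUResnet50_ep0_fold0.csv')  # placeholder path/epoch/fold
errprob = errprob.sort_values('p1-p0')
n_extreme = 100                                              # how many rows to take from each end (assumption)
del_img = pd.concat([errprob.head(n_extreme), errprob.tail(n_extreme)])
del_img.to_csv('delImg.csv', index=False)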


# Delete selected records from the CSV

import os
import pandas as pd

dir_train_img = '/media/zhujunjie/dataset/baMelanoma_data'
path_img = '/media/zhujunjie/dataset/baMelanoma_data/train'
train_csv_path = '/media/zhujunjie/dataset/train.csv'
del_csv_path = '/media/zhujunjie/dataset/delImg.csv'
del_train = pd.read_csv(del_csv_path)
trndf = pd.read_csv(train_csv_path)
print('del_csv_img shape {} {}'.format(*del_train.shape))
print('trn_csv_img shape {} {}'.format(*trndf.shape))
del_pngs = [img_id for img_id in del_train['image_id']]
print('Count of del_pngs : {}'.format(len(del_pngs)))
# drop the selected image ids
trndf = trndf.set_index('image_id')
trndf.drop(index=del_pngs, axis=0, inplace=True)
# restore image_id as a regular column
trndf = trndf.reset_index()
# save the new label file with the selected rows removed
print('new_csv_img shape {} {}'.format(*trndf.shape))
trndf.to_csv(os.path.join(dir_train_img, "new_train.csv"), index=False)
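As a quick check that the drop worked, none of the removed image ids should remain in the new label file. A small sketch, continuing from the script above:

# Verify that no dropped id survived in new_train.csv.
check = pd.read_csv(os.path.join(dir_train_img, "new_train.csv"))
remaining = check['image_id'].isin(del_pngs).sum()
print('removed ids still present: {}'.format(remaining))  # expected: 0
assert remaining == 0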