code’github link: https://github.com/clovaai/wsolevaluation
该博客是ADL的backbone结构的学习笔记:
ResNet50残差块定义:
class Bottleneck(nn.Module):expansion = 4# 表示输出的channel的膨胀系数为4 64-->256def __init__(self, inplanes, planes, stride=1, downsample=None,base_width=64):super(Bottleneck, self).__init__()'''下面分别是三个Conv层1x1, 3x3, 1x1,其中最后一层1x1完成planes*4的out_channel定型如果有下采样需求,在第一层的stride=stride中完成'''width = int(planes * (base_width / 64.))self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False)self.bn1 = nn.BatchNorm2d(width)self.conv2 = nn.Conv2d(width, width, 3,stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(width)self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False)# 在最后一层1x1卷积核完成通道的增加(4倍)self.bn3 = nn.BatchNorm2d(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)#如果说残差块输入的channel和输出的channel不一致#如果有下采样需求(stride>1,比如说为2),跑出来的结果就是残差块里的残差# downsample的具体操作在 get_downsampling_layer 函数中体现if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out
具体的ResNetADL的网络模型构建:
_ADL_POSITION = [[], [], [], [0], [0, 2]]'''在第几层添加ADL模块,论文中有实验数据表明,由于层数越深,特征图像素点的感受野越大,drop_mask所drop的区域面积也就越大'''class ResNetAdl(nn.Module):def __init__(self, block, layers, num_classes=1000,large_feature_map=False, **kwargs):super(ResNetAdl, self).__init__()self.stride_l3 = 1 if large_feature_map else 2'''large_feature_map=Ture self.stride_l3=1large_feature_map=False self.stride_l3=2特征图会大一些,在原ResNet中作者就这个问题给出了一点说明:'it is slightly better whereas slower to set stride = 1''''self.inplanes = 64#ADL模块的两个参数self.adl_drop_rate = kwargs['adl_drop_rate']self.adl_threshold = kwargs['adl_drop_threshold']self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2,padding=3, bias=False)# 这里可以把7x7的卷积核改为3个3x3的卷积核防止信息的丢失self.bn1 = nn.BatchNorm2d(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0],stride=1,split=_ADL_POSITION[1])self.layer2 = self._make_layer(block, 128, layers[1],stride=2,split=_ADL_POSITION[2])self.layer3 = self._make_layer(block, 256, layers[2],stride=self.stride_l3,split=_ADL_POSITION[3])# it is slightly better whereas slower to set stride = 1self.layer4 = self._make_layer(block, 512, layers[3],stride=1,split=_ADL_POSITION[4])self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)initialize_weights(self.modules(), init_mode='xavier')def forward(self, x, labels=None, return_cam=False):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)pre_logit = self.avgpool(x)pre_logit = pre_logit.reshape(pre_logit.size(0), -1)logits = self.fc(pre_logit)if return_cam:feature_map = x.detach().clone()'''返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false这里把第4个残差模块的feature map给clone下来保存'''cam_weights = self.fc.weight[labels]# 再获得相应label对应feature map的权重cams = (cam_weights.view(*feature_map.shape[:2], 1, 1) *feature_map).mean(1, keepdim=False)return camsreturn {'logits': logits}def _make_layer(self, block, planes, blocks, stride, split=None):layers = self._layer(block, planes, blocks, stride)# layer type:list 用于储存网络的模块for pos in reversed(split):# 将ADL模块添加到layers里,分别添加到stage 4 的第一层后面和第三层后面layers.insert(pos + 1, ADL(self.adl_drop_rate, self.adl_threshold))return nn.Sequential(*layers)def _layer(self, block, planes, blocks, stride):downsample = get_downsampling_layer(self.inplanes, block, planes,stride)layers = [block(self.inplanes, planes, stride, downsample)]self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes))return layersdef get_downsampling_layer(inplanes, block, planes, stride):outplanes = planes * block.expansionif stride == 1 and inplanes == outplanes:returnelse:return nn.Sequential(nn.Conv2d(inplanes, outplanes, 1, stride, bias=False),nn.BatchNorm2d(outplanes),)
总体结构就是在ResNet上稍微做了做修改,应该很好理解。
