项目设计

使用 NAS 进行神经网络搜索共分为两大步:

  1. 搜索出性能最好的 cell(normal cell 和 reduction cell)。
  2. 使用 cell 进行堆叠得到 network 并训练它。

根据上述步骤,设计的项目结构如下:
image.png

  1. cnn 用于搜索单个 cell;
  2. cnn_all 用于存储搜索好的 cell 并进行堆叠得到 network,然后训练 network;
  3. data 用于存储数据集,分为训练数据和验证数据。

代码结构设计

本项目的代码是在 DARTS 模型源码的基础上进行改写的,因此重点讲解与 DARTS 模型不同的部分,关于 DARTS 模型源码的解析,源代码里有较多的注释。
cnn 与 cnn_all 内代码结构非常相似,了解其中一个即可。

  1. architecture.py:网络底层结构,包括正向传播和反向传播。
  2. genotypes.py:搜索完后,存储搜索好的结构进行训练。
  3. model.py 和 model_search.py:DARTS 模型 Cell 单元计算与拼接。
  4. mul_search.py:核心部分,搜索算法的调整与实现。
  5. operations.py:搜索空间的各项操作。
  6. train.py:核心部分,main 函数所在,实现网络搜索和评估过程。
  7. train_cifar100:同上,用 cifar-100 数据集。
  8. train_test.py:同上,用于测试,无作用。
  9. train_test_model.py 和 train_test_model_node.py:改写后的 Cell 单元计算与拼接。
  10. util.py:全局公用函数。
  11. visualize.py:网络层图片生成,用于观察最终训练得到的 normal_cell 和 reduction_cell。

Cell 的定义

在 train_test_model_node.py 中定义 Cell 类:
DARTS 模型通过给每个操作赋予权重,然后加权加和得到混合操作,再对权重进行梯度下降,最后选择其中最好的操作,而这里改成了根据入参直接选择其中一个(其实是根据概率的选择)。

  1. # Cell 定义
  2. class Cell(nn.Module):
  3. def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, genotype):
  4. super(Cell, self).__init__()
  5. self.reduction = reduction
  6. # input nodes的结构固定不变,不参与搜索
  7. # 决定第一个input nodes的结构,取决于前一个cell是否是reduction
  8. if reduction_prev:
  9. self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
  10. else:
  11. #第一个input_nodes是cell k-2的输出,cell k-2的输出通道数为C_prev_prev,所以这里操作的输入通道数为C_prev_prev
  12. self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
  13. self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
  14. self._steps = steps # 每个cell中有4个节点的连接状态待确定
  15. self._multiplier = multiplier
  16. self._ops = nn.ModuleList() # 构建operation的modulelist
  17. self._bns = nn.ModuleList()
  18. # for i in range(self._steps):
  19. # for j in range(2+i):
  20. # stride = 2 if reduction and j < 2 else 1
  21. # op = MixedOp(C, stride)
  22. # self._ops.append(op)
  23. # 不再加权求和计算混合操作
  24. # 而是直接选
  25. if reduction:
  26. op_names, self.indices = zip(*genotype.reduce)
  27. self._concat = genotype.reduce_concat
  28. else:
  29. op_names, self.indices = zip(*genotype.normal)
  30. self._concat = genotype.normal_concat
  31. for i in range(14):
  32. stride = 2 if reduction and self.indices[i] < 2 else 1
  33. op = MixedOp(C, stride, op_names[i])
  34. self._ops.append(op)
  35. #def forward(self, s0, s1, weights):
  36. # cell中的计算过程,前向传播时自动调用
  37. # 可以看到与原来相比少了 weights 参数
  38. def forward(self, s0, s1):
  39. s0 = self.preprocess0(s0)
  40. s1 = self.preprocess1(s1)
  41. states = [s0, s1] # 当前节点的前驱节点
  42. offset = 0
  43. #遍历每个intermediate nodes,得到每个节点的output
  44. for i in range(self._steps):
  45. s = sum(self._ops[offset+j](h) for j, h in enumerate(states))
  46. offset += len(states)
  47. # 把当前节点i的output作为下一个节点的输入
  48. # states中为[s0,s1,b1,b2,b3,b4] b1,b2,b3,b4分别是四个intermediate output的输出
  49. # 对intermediate的output进行concat作为当前cell的输出
  50. states.append(s)
  51. # for i in range(self._steps):
  52. # s = self._ops[2*i](states[self.indices[2*i]])
  53. # s = s + self._ops[2*i+1](states[self.indices[2*i+1]])
  54. # states.append(s)
  55. #return torch.cat([states[i] for i in self._concat], dim=1)
  56. # dim=1是指对通道这个维度concat,所以输出的通道数变成原来的4倍
  57. return torch.cat(states[-self._multiplier:], dim=1)

混合操作

在 train_test_model_node.py 中定义 MixedOp 类,前面提到的 DARTS 模型的混合操作被改为根据概率选择,但仍沿用了类名 MixedOp:

  1. # 混合操作
  2. class MixedOp(nn.Module):
  3. def __init__(self, C, stride, primitive):
  4. super(MixedOp, self).__init__()
  5. self._ops = nn.ModuleList()
  6. # 不再采用 DARTS 模型里通过给每个操作赋予权重
  7. # 然后加权求和得到混合操作的方式
  8. # 而是直接选择其中一个
  9. op = OPS[primitive](C, stride, False)
  10. if 'pool' in primitive:
  11. op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False)) # 给池化操作后面加一个batchnormalization
  12. self._ops.append(op)
  13. # def forward(self, x, weights):
  14. # return sum(w * op(x) for w, op in zip(weights, self._ops))
  15. def forward(self, x):
  16. op = self._ops[0]
  17. return op(x)

Cell 的搜索

DARTS 模型通过梯度下降搜索 Cell 结构,需要面对一个双层优化问题,经过一系列近似操作,也给出了近似算法,但本项目是基于概率进行搜索的,在 mul_search.py 中实现了搜索策略的调整,代码较长,这里仅理一下逻辑。

  1. # 导入相关库
  2. import ...
  3. # 定义 MUL 类
  4. class MUL(nn.Module):
  5. # 初始化
  6. def __init__(self, steps=4, multiplier=4):
  7. super(MUL, self).__init__()
  8. self._steps = steps
  9. self._multiplier = multiplier
  10. self.base = 0.02
  11. self._initialize_alphas()
  12. # 初始化 alpha
  13. def _initialize_alphas(self):
  14. ......
  15. # normal_cell 的 alpha 初始化
  16. self.alphas_normal = {"opt":np.zeros((k, num_ops)),
  17. "epoch":np.zeros((k, num_ops)),
  18. "accurcy":np.zeros((k, num_ops))
  19. }
  20. for x in range(k):
  21. for j in range(num_ops):
  22. self.alphas_normal["opt"][x][j]=1/7.0 # 等概率 1/7
  23. self.alphas_normal["opt"][x][0]=0.0
  24. # reduction_cell 的 alpha 初始化
  25. self.alphas_reduce
  26. # normal_cell 的 edge 初始化
  27. self.edge_normal = {"edge":np.array([0.5,0.5,1/3.0,1/3.0,1/3.0,0.25,0.25,0.25,0.25,0.2,0.2,0.2,0.2,0.2]),
  28. "epoch":np.zeros((k,1)),
  29. "accurcy":np.zeros((k,1))
  30. }
  31. # reduction_cell 的 edge 初始化
  32. self.edge_reduce = ......
  33. # 单次搜索两个节点
  34. def genotype(self):
  35. #
  36. def _parse(weights,pros):
  37. ......
  38. return gene
  39. .......
  40. return genotype
  41. # 搜索边
  42. def genotype_edge(self):
  43. def _parse(weights,pros):
  44. # 更新概率,需要依靠 accurcy 和 epoch 表
  45. def update_probability(self,accurcy,genotype):
  46. # 更新 accurcy 和 epoch 表
  47. def updating(a1, a2, b1, b2):
  48. def update(accurcy,weightN,weightR):
  49. # 更新选择每条边的概率
  50. def update_probability_edge(self,accurcy,genotype):
  51. def renew():
  52. def update(accurcy,weightN,weightR):
  53. # 保存参数
  54. def save(self):
  55. # 加载参数
  56. def load(self):
  57. # 搜索全部节点
  58. def genotype_all(self):
  59. #
  60. def _parse(weights,pros):

Cell 的训练

在 train.py 中的 main 函数,损失函数和优化器定义如下:

  1. # 交叉熵损失函数
  2. criterion = nn.CrossEntropyLoss()
  3. criterion = criterion.cuda()
  4. # 用于优化权重w的带动量的SGD优化器
  5. optimizer = torch.optim.SGD(
  6. model.parameters(),
  7. args.learning_rate,
  8. momentum=args.momentum,
  9. weight_decay=args.weight_decay
  10. )

训练集与验证集的拆分:

  1. # 原来训练集的前半部分作为现在的训练集
  2. train_queue = torch.utils.data.DataLoader(
  3. train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=2)
  4. # 原来训练集的后半部分作为现在的验证集
  5. valid_queue = torch.utils.data.DataLoader(
  6. valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)
  7. # 优化权重w时,学习率调整用的是余弦退火(SGDR),但只训练50个epoch,其实就相当于cos学习率衰减,没有周期变化
  8. scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

迭代训练:

  1. for epoch in range(args.epochs):
  2. #scheduler.step()
  3. logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
  4. model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
  5. # 进行训练
  6. train_acc, train_obj = train(train_queue, model, criterion, optimizer)
  7. logging.info('train_acc %f', train_acc)
  8. # 进行验证
  9. with torch.no_grad():
  10. valid_acc, valid_obj = infer(valid_queue, model, criterion)
  11. logging.info('valid_acc %f', valid_acc)
  12. # 保存一下模型
  13. scheduler.step()
  14. utils.save(model, os.path.join(args.save, 'weights.pt'))

验证函数和验证函数:

  1. def train(train_queue, model, criterion, optimizer):
  2. def infer(valid_queue, model, criterion):

使用训练集进行训练,包括前向传播和反向传播;验证使用验证集,相同的损失函数。

Cell 堆叠构建 Network

在 train_test_model_node.py 中定义 Network 类:

  1. # 神经网络定义
  2. class Network(nn.Module):
  3. def __init__(self, C, num_classes, layers, genotype, steps=4, multiplier=4, stem_multiplier=3):
  4. super(Network, self).__init__()
  5. self._C = C #初始通道数
  6. self._num_classes = num_classes
  7. self._layers = layers
  8. #self._criterion = criterion
  9. self._steps = steps # 一个基本单元cell内有4个节点需要进行operation操作的搜索
  10. self._multiplier = multiplier
  11. C_curr = stem_multiplier*C # 当前Sequential模块的输出通道数
  12. self.stem = nn.Sequential(
  13. nn.Conv2d(3, C_curr, 3, padding=1, bias=False), #前三个参数分别是输入图片的通道数,卷积核的数量,卷积核的大小
  14. nn.BatchNorm2d(C_curr) # BatchNorm2d对minibatch 3d数据组成的4d输入进行batchnormalization操作,num_features为(N,C,H,W)的C
  15. )
  16. C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
  17. self.cells = nn.ModuleList() # 创建一个空modulelist类型数据
  18. reduction_prev = False # 连接的前一个cell是否是reduction cell
  19. for i in range(layers): # 网络是8层,在1/3和2/3位置是reduction cell 其他是normal cell,reduction cell的stride是2
  20. if i in [layers//3, 2*layers//3]: # 对应论文的Cells located at the 1/3 and 2/3 of the total depth of the network are reduction cells
  21. C_curr *= 2
  22. reduction = True
  23. else:
  24. reduction = False
  25. # 构建cell
  26. # 每个cell的input nodes是前前cell和前一个cell的输出
  27. cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, genotype)
  28. reduction_prev = reduction
  29. self.cells += [cell]
  30. # C_prev=multiplier*C_curr是因为每个cell的输出是4个中间节点concat的,这个concat是在通道这个维度,所以输出的通道数变为原来的4倍
  31. C_prev_prev, C_prev = C_prev, multiplier*C_curr
  32. self.global_pooling = nn.AdaptiveAvgPool2d(1) #构建一个平均池化层,output size是1x1
  33. self.classifier = nn.Linear(C_prev, num_classes) #构建一个线性分类器
  34. def forward(self, input):
  35. s0 = s1 = self.stem(input)
  36. for i, cell in enumerate(self.cells):
  37. s0, s1 = s1, cell(s0, s1)
  38. out = self.global_pooling(s1)
  39. logits = self.classifier(out.view(out.size(0),-1))
  40. return logits

Network 训练

与 Cell 的训练差不多,不再赘述。

工具类解析

在 util.py 中定义了许多全局公用函数,仅说明作用:

  1. import os
  2. import numpy as np
  3. import torch
  4. import shutil
  5. import torchvision.transforms as transforms
  6. # 用于计算平均值
  7. class AvgrageMeter(object):
  8. def __init__(self):
  9. def reset(self):
  10. def update(self, val, n=1):
  11. # 求top-k精度
  12. def accuracy(output, target, topk=(1,)):
  13. # 数据增强:Cutout,生成一个边长为length的正方形遮掩(越过边界的话就变成矩形了)
  14. class Cutout(object):
  15. # 用于CIFAR的数据增强操作
  16. def _data_transforms_cifar10(args):
  17. # 统计参数量(MB)
  18. def count_parameters_in_MB(model):
  19. # 保存checkpoint,同时如果是最好模型的话也会copy一下
  20. def save_checkpoint(state, is_best, save):
  21. # 保存模型
  22. def save(model, model_path):
  23. # 载入模型
  24. def load(model, model_path):
  25. # 随机丢弃路径,来自FractalNet
  26. def drop_path(x, drop_prob):
  27. # 创建文件夹,copy文件的一些操作
  28. def create_exp_dir(path, scripts_to_save=None):