前面已经介绍了几乎所有的训练代码。在训练代码中,比较重要的步骤就是网络结构的搭建,这也是SSD算法的核心。接下来要介绍的symbol_factory.py是网络结构搭建的起始脚本,主要包含网络的一些配置信息。

    该脚本主要包含get_config、get_symbol_train、get_symbol三个函数,其中后两个函数基本相同,可以从get_symbol_train函数开始看。该函数中大部分是检测相关的参数配置,更详细的内容在另一个脚本(symbol_builder.py)中实现。

    1. """Presets for various network configurations"""
    2. from __future__ import absolute_import
    3. import logging
    4. from . import symbol_builder
    # Generate the detection-head parameter configuration for a given base
    # network and input data size.
    def get_config(network, data_shape, **kwargs):
        """Configuration factory for various networks

        Parameters
        ----------
        network : str
            base network name; presets exist for 'vgg16_reduced',
            'inceptionv3', 'resnet50', 'resnet101' and 'mobilenet'
        data_shape : int
            input data dimension
        kwargs : dict
            extra arguments; not interpreted here, only echoed back under the
            'kwargs' key of the result

        Returns
        -------
        dict
            A new dict holding 'network', 'data_shape', 'kwargs',
            'from_layers', 'num_filters', 'strides', 'pads', 'sizes',
            'ratios', 'normalizations' and 'steps'. For the resnet presets
            'network' is rewritten to 'resnet' and 'num_layers' and
            'image_shape' are added.

        Raises
        ------
        NotImplementedError
            if no preset exists for `network`
        """
        # Expose exactly the keys the previous `locals()`-based return did
        # (including the echoed arguments) so callers that splat the result
        # into symbol_builder keep working.
        config = {'network': network, 'data_shape': data_shape, 'kwargs': kwargs}

        if network == 'vgg16_reduced':
            config.update(_vgg16_reduced_head(data_shape))
            if data_shape not in (300, 512):
                # `logging.warn` is a deprecated alias of `logging.warning`;
                # passing the argument lazily leaves the emitted text unchanged.
                logging.warning('data_shape %d was not tested, use with caucious.', data_shape)
        elif network == 'inceptionv3':
            config.update(_generic_six_layer_head(
                ['ch_concat_mixed_7_chconcat', 'ch_concat_mixed_10_chconcat', '', '', '', '']))
        elif network in ('resnet50', 'resnet101'):
            config['num_layers'] = 50 if network == 'resnet50' else 101
            config['image_shape'] = '3,224,224'  # resnet require it as shape check
            config['network'] = 'resnet'
            config.update(_generic_six_layer_head(['_plus12', '_plus15', '', '', '', '']))
        elif network == 'mobilenet':
            config.update(_generic_six_layer_head(
                ['activation22', 'activation26', '', '', '', '']))
        else:
            msg = 'No configuration found for %s with data_shape %d' % (network, data_shape)
            raise NotImplementedError(msg)
        return config


    def _vgg16_reduced_head(data_shape):
        """Return fresh head settings for the 'vgg16_reduced' preset."""
        if data_shape >= 448:
            # Compared with the smaller-input preset below, the 512*512 case
            # appends five extra convolution layers (the '' entries) instead
            # of four.
            return {
                'from_layers': ['relu4_3', 'relu7', '', '', '', '', ''],
                'num_filters': [512, -1, 512, 256, 256, 256, 256],
                'strides': [-1, -1, 2, 2, 2, 2, 1],
                'pads': [-1, -1, 1, 1, 1, 1, 1],
                'sizes': [[.07, .1025], [.15, .2121], [.3, .3674], [.45, .5196],
                          [.6, .6708], [.75, .8216], [.9, .9721]],
                'ratios': [[1, 2, .5], [1, 2, .5, 3, 1./3], [1, 2, .5, 3, 1./3],
                           [1, 2, .5, 3, 1./3], [1, 2, .5, 3, 1./3],
                           [1, 2, .5], [1, 2, .5]],
                # A normalization of -1 means the layer is not normalized.
                'normalizations': [20, -1, -1, -1, -1, -1, -1],
                # Steps are only preset for the exact 512 input size.
                'steps': [] if data_shape != 512 else
                         [x / 512.0 for x in [8, 16, 32, 64, 128, 256, 512]],
            }
        # relu4_3 is the 3rd convolution of the 4th block of the original
        # VGG16 and relu7 is its fc7 layer (VGG16 ends with fc6, fc7, fc8);
        # the four '' entries are convolution layers added on top of VGG16.
        # These six layers are the ones features are extracted from.
        return {
            'from_layers': ['relu4_3', 'relu7', '', '', '', ''],
            # Filter count per layer; -1 marks extracted features.
            'num_filters': [512, -1, 512, 256, 256, 256],
            'strides': [-1, -1, 2, 2, 1, 1],
            'pads': [-1, -1, 1, 1, 0, 0],
            'sizes': [[.1, .141], [.2, .272], [.37, .447], [.54, .619],
                      [.71, .79], [.88, .961]],
            'ratios': [[1, 2, .5], [1, 2, .5, 3, 1./3], [1, 2, .5, 3, 1./3],
                       [1, 2, .5, 3, 1./3], [1, 2, .5], [1, 2, .5]],
            'normalizations': [20, -1, -1, -1, -1, -1],
            # Steps are only preset for the exact 300 input size.
            'steps': [] if data_shape != 300 else
                     [x / 300.0 for x in [8, 16, 32, 64, 100, 300]],
        }


    def _generic_six_layer_head(from_layers):
        """Return fresh head settings shared by the non-VGG presets."""
        # New list objects are built on every call so that a caller mutating
        # one returned config cannot affect later ones.
        return {
            'from_layers': from_layers,
            'num_filters': [-1, -1, 512, 256, 256, 128],
            'strides': [-1, -1, 2, 2, 2, 2],
            'pads': [-1, -1, 1, 1, 1, 1],
            'sizes': [[.1, .141], [.2, .272], [.37, .447], [.54, .619],
                      [.71, .79], [.88, .961]],
            'ratios': [[1, 2, .5], [1, 2, .5, 3, 1./3], [1, 2, .5, 3, 1./3],
                       [1, 2, .5, 3, 1./3], [1, 2, .5], [1, 2, .5]],
            'normalizations': -1,
            'steps': [],
        }
    # Per the original author's note: this function obtains the structure of
    # the whole training network, including the newly added layers and the
    # anchor-generating layers. It mainly delegates to symbol_builder to
    # import and generate the symbol.
    def get_symbol_train(network, data_shape, **kwargs):
        """Wrapper for get symbol for train

        Parameters
        ----------
        network : str
            name for the base network symbol
        data_shape : int
            input shape
        kwargs : dict
            see symbol_builder.get_symbol_train for more details

        Returns
        -------
        Whatever the delegated `get_symbol_train` call returns.
        """
        if network.startswith('legacy'):
            # `logging.warn` is a deprecated alias; `warning` is the supported name.
            logging.warning('Using legacy model.')
            # Legacy networks bring their own builder and skip the presets.
            return symbol_builder.import_module(network).get_symbol_train(**kwargs)
        # Start from the preset configuration for this network / input size ...
        config = get_config(network, data_shape, **kwargs).copy()
        # ... then let caller-supplied keyword arguments override or extend it.
        config.update(kwargs)
        # Hand the merged settings to symbol_builder to produce the symbol.
        return symbol_builder.get_symbol_train(**config)
    def get_symbol(network, data_shape, **kwargs):
        """Wrapper for get symbol for test

        Parameters
        ----------
        network : str
            name for the base network symbol
        data_shape : int
            input shape
        kwargs : dict
            see symbol_builder.get_symbol for more details

        Returns
        -------
        Whatever the delegated `get_symbol` call returns.
        """
        if network.startswith('legacy'):
            # `logging.warn` is a deprecated alias; `warning` is the supported name.
            logging.warning('Using legacy model.')
            # Legacy networks bring their own builder and skip the presets.
            return symbol_builder.import_module(network).get_symbol(**kwargs)
        # Preset configuration first, then caller-supplied overrides on top.
        config = get_config(network, data_shape, **kwargs).copy()
        config.update(kwargs)
        return symbol_builder.get_symbol(**config)