Spatial Pyramid Pooling (SPP) layer 常用于 CNN 中,处理不定长大小的图片。

    image.png

    如上图所示,假设经过一些列 conv layers 后,得到一个 [256, x, y] 的 vector。其中 256 是上一个卷积层 filter 的数量,x和y 可取任意大小。

    假设 256 为 1,即 input 为 [x, y]

    • Step 1. 把 [x, y] 分成 16 份,每份取 max,即可得到 [1, 16] 的 vector。
    • Step 2. 把 Step 1 中的 16 换成 4,即可得 [1, 4] 的 vector 。
    • Step 3. 把 Step 1 中的 16 换成 1,即可得 [1, 1] 的 vector 。
    • Step 4. 把 Step 1,2,3 中的 vector 拼接起来,即得 [1, 21] 的 vetor。

    如果 filters 为 256 的话,即可得 [1, 256 * 21] 的 vector。

    代码如下:

    1. import torch
    2. import torch.nn as nn
    3. from torch.nn import init
    4. import functools
    5. from torch.autograd import Variable
    6. import numpy as np
    7. import torch.nn.functional as F
    8. import math
    9. def spatial_pyramid_pool(previous_conv, num_sample, previous_conv_size, out_pool_size):
    10. '''
    11. previous_conv: a tensor vector of previous convolution layer
    12. num_sample: an int number of image in the batch
    13. previous_conv_size: an int vector [height, width] of the matrix features size of previous convolution layer
    14. out_pool_size: a int vector of expected output size of max pooling layer
    15. returns: a tensor vector with shape [1 x n] is the concentration of multi-level pooling
    16. '''
    17. # print(previous_conv.size())
    18. for i in range(len(out_pool_size)):
    19. # print(previous_conv_size)
    20. h_wid = int(math.ceil(previous_conv_size[0] / out_pool_size[i]))
    21. w_wid = int(math.ceil(previous_conv_size[1] / out_pool_size[i]))
    22. h_pad = int((h_wid*out_pool_size[i] - previous_conv_size[0] + 1)/2)
    23. w_pad = int((w_wid*out_pool_size[i] - previous_conv_size[1] + 1)/2)
    24. # print("h_pad:",h_pad)
    25. # print("w_pad:",w_pad)
    26. maxpool = nn.MaxPool2d((h_wid, w_wid), stride=(h_wid, w_wid), padding=(h_pad, w_pad))
    27. x = maxpool(previous_conv)
    28. if(i == 0):
    29. spp = x.view(num_sample,-1)
    30. # print("spp size:",spp.size())
    31. else:
    32. # print("size:",spp.size())
    33. spp = torch.cat((spp,x.view(num_sample,-1)), 1)
    34. return spp

    测试结果:
    image.png

    参考:
    https://github.com/yueruchen/sppnet-pytorch
    https://arxiv.org/abs/1804.02047