2D Keras implementation
    https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981

    from keras.utils import conv_utils
    from keras.engine.topology import Layer
    import keras.backend as K


    class PixelShuffler(Layer):
        def __init__(self, size=(2, 2), data_format=None, **kwargs):
            super(PixelShuffler, self).__init__(**kwargs)
            self.data_format = conv_utils.normalize_data_format(data_format)
            self.size = conv_utils.normalize_tuple(size, 2, 'size')

        def call(self, inputs):
            input_shape = K.int_shape(inputs)
            if len(input_shape) != 4:
                raise ValueError('Inputs should have rank ' +
                                 str(4) +
                                 '; Received input shape:', str(input_shape))

            if self.data_format == 'channels_first':
                batch_size, c, h, w = input_shape
                if batch_size is None:
                    batch_size = -1
                rh, rw = self.size
                oh, ow = h * rh, w * rw
                oc = c // (rh * rw)

                out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
                out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
                out = K.reshape(out, (batch_size, oc, oh, ow))
                return out

            elif self.data_format == 'channels_last':
                batch_size, h, w, c = input_shape
                if batch_size is None:
                    batch_size = -1
                rh, rw = self.size
                oh, ow = h * rh, w * rw
                oc = c // (rh * rw)

                out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
                out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
                out = K.reshape(out, (batch_size, oh, ow, oc))
                return out

        def compute_output_shape(self, input_shape):
            if len(input_shape) != 4:
                raise ValueError('Inputs should have rank ' +
                                 str(4) +
                                 '; Received input shape:', str(input_shape))

            if self.data_format == 'channels_first':
                height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
                width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
                channels = input_shape[1] // self.size[0] // self.size[1]

                if channels * self.size[0] * self.size[1] != input_shape[1]:
                    raise ValueError('channels of input and size are incompatible')

                return (input_shape[0],
                        channels,
                        height,
                        width)

            elif self.data_format == 'channels_last':
                height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
                width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
                channels = input_shape[3] // self.size[0] // self.size[1]

                if channels * self.size[0] * self.size[1] != input_shape[3]:
                    raise ValueError('channels of input and size are incompatible')

                return (input_shape[0],
                        height,
                        width,
                        channels)

        def get_config(self):
            config = {'size': self.size,
                      'data_format': self.data_format}
            base_config = super(PixelShuffler, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))
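
    A minimal usage sketch (not from the gist, sizes are only illustrative), assuming the old standalone Keras API that the gist targets (keras.engine.topology): a Conv2D that outputs r*r*C channels followed by PixelShuffler((2, 2)) doubles the spatial resolution.

    from keras.layers import Input, Conv2D
    from keras.models import Model

    # Hypothetical toy model: 3-channel 32x32 input, upsampled x2 by pixel shuffling.
    inp = Input(shape=(32, 32, 3))
    x = Conv2D(3 * 2 * 2, kernel_size=3, padding='same')(inp)  # r*r*C = 12 channels
    x = PixelShuffler(size=(2, 2))(x)                           # -> (None, 64, 64, 3)
    model = Model(inp, x)
    model.summary()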

    https://github.com/Nico-Curti/NumPyNet/blob/master/NumPyNet/layers/shuffler_layer.py
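
    The NumPyNet link above implements the same shuffler in plain NumPy; the core operation reduces to one reshape, one transpose, and a final reshape. A minimal sketch of that idea for channels_last layout (my own illustration, not the library's actual API), matching the Keras layer above:

    import numpy as np

    def pixel_shuffle_nhwc(x, r):
        # x: (N, H, W, C*r*r) -> (N, H*r, W*r, C)
        n, h, w, c = x.shape
        oc = c // (r * r)
        out = x.reshape(n, h, w, r, r, oc)
        out = out.transpose(0, 1, 3, 2, 4, 5)   # interleave the r x r sub-blocks with H and W
        return out.reshape(n, h * r, w * r, oc)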

3D PyTorch implementation

    import torch.nn as nn


    class PixelShuffle3d(nn.Module):
        '''
        This class is a 3D version of pixel shuffle.
        '''
        def __init__(self, scale):
            '''
            :param scale: upsample scale
            '''
            super().__init__()
            self.scale = scale

        def forward(self, input):
            batch_size, channels, in_depth, in_height, in_width = input.size()
            nOut = channels // self.scale ** 3

            out_depth = in_depth * self.scale
            out_height = in_height * self.scale
            out_width = in_width * self.scale

            input_view = input.contiguous().view(batch_size, nOut, self.scale, self.scale, self.scale, in_depth, in_height, in_width)
            output = input_view.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()

            return output.view(batch_size, nOut, out_depth, out_height, out_width)
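
    A quick shape check of the 3D layer (tensor sizes chosen only for illustration); for the 2D case, PyTorch already ships torch.nn.PixelShuffle, which this class mirrors for volumes.

    import torch

    # 8 = 1 * 2**3 input channels collapse to 1 output channel,
    # while each spatial dimension is upsampled by the scale factor.
    x = torch.randn(2, 8, 4, 4, 4)
    ps = PixelShuffle3d(scale=2)
    y = ps(x)
    print(y.shape)  # torch.Size([2, 1, 8, 8, 8])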