2D Keras implementation
https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981
# NOTE(review): these import paths come from standalone Keras 2.x; newer
# releases moved `Layer` and `normalize_data_format` -- confirm against the
# installed Keras version.
from keras.utils import conv_utils
from keras.engine.topology import Layer
import keras.backend as K


class PixelShuffler(Layer):
    """Sub-pixel (depth-to-space) upsampling layer for rank-4 feature maps.

    Trades channels for spatial resolution. With ``size = (rh, rw)``:

    * ``channels_last`` input ``(N, H, W, C)`` becomes
      ``(N, H * rh, W * rw, C // (rh * rw))``;
    * ``channels_first`` input ``(N, C, H, W)`` becomes
      ``(N, C // (rh * rw), H * rh, W * rw)``.

    Channel ordering (both data formats): input channel
    ``(i * rw + j) * oc + k`` lands at offset ``(i, j)`` inside each upsampled
    ``rh x rw`` cell of output channel ``k``, so the output-channel index
    varies fastest. This matches ``tf.nn.depth_to_space`` (NHWC). It differs
    from ``torch.nn.PixelShuffle``, where that index varies slowest.

    NOTE: ``call`` reads static shapes via ``K.int_shape``. The batch
    dimension may be unknown, but H and W must be known there
    (``None * rh`` raises TypeError).

    :param size: upsampling factor, an int or an ``(rh, rw)`` pair.
    :param data_format: ``'channels_first'`` or ``'channels_last'``. ``None``
        lets ``conv_utils.normalize_data_format`` pick the default
        (presumably the Keras ``image_data_format`` setting -- confirm).
    :raises ValueError: from ``call`` / ``compute_output_shape`` if the input
        is not rank 4 or its channel count is not divisible by ``rh * rw``.
    """

    def __init__(self, size=(2, 2), data_format=None, **kwargs):
        super(PixelShuffler, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 2, 'size')

    def _split_shape(self, shape):
        """Return ``(batch, channels, height, width)`` from a rank-4 shape.

        :raises ValueError: if ``shape`` is not rank 4 or ``data_format`` is
            neither ``'channels_first'`` nor ``'channels_last'``.
        """
        if len(shape) != 4:
            # Build a single message string. Passing the shape as a second
            # argument made ValueError render its message as a tuple.
            raise ValueError('Inputs should have rank 4; Received input shape: '
                             + str(shape))
        if self.data_format == 'channels_first':
            batch, channels, height, width = shape
        elif self.data_format == 'channels_last':
            batch, height, width, channels = shape
        else:
            raise ValueError('Unsupported data_format: ' + str(self.data_format))
        return batch, channels, height, width

    def _shuffled_channels(self, channels):
        """Return the output channel count, checking ``channels`` splits evenly."""
        rh, rw = self.size
        if channels % (rh * rw) != 0:
            raise ValueError('channels of input and size are incompatible')
        return channels // (rh * rw)

    def call(self, inputs):
        """Move ``rh * rw`` channel groups of ``inputs`` into spatial positions."""
        batch_size, c, h, w = self._split_shape(K.int_shape(inputs))
        if batch_size is None:
            batch_size = -1  # unknown batch size: let K.reshape infer it
        rh, rw = self.size
        oc = self._shuffled_channels(c)
        oh, ow = h * rh, w * rw
        if self.data_format == 'channels_first':
            # (N, rh*rw*oc, H, W) -> (N, rh, rw, oc, H, W) -> (N, oc, H, rh, W, rw)
            out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
            out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
            return K.reshape(out, (batch_size, oc, oh, ow))
        # channels_last:
        # (N, H, W, rh*rw*oc) -> (N, H, W, rh, rw, oc) -> (N, H, rh, W, rw, oc)
        out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
        out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
        return K.reshape(out, (batch_size, oh, ow, oc))

    def compute_output_shape(self, input_shape):
        """Return the output shape; unknown (``None``) spatial dims stay ``None``."""
        batch, c, h, w = self._split_shape(input_shape)
        rh, rw = self.size
        height = h * rh if h is not None else None
        width = w * rw if w is not None else None
        channels = self._shuffled_channels(c)
        if self.data_format == 'channels_first':
            return (batch, channels, height, width)
        return (batch, height, width, channels)

    def get_config(self):
        """Return the base Layer config extended with ``size`` and ``data_format``."""
        config = {'size': self.size,
                  'data_format': self.data_format}
        base_config = super(PixelShuffler, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
https://github.com/Nico-Curti/NumPyNet/blob/master/NumPyNet/layers/shuffler_layer.py
3D PyTorch implementation
from torch import nn


class PixelShuffle3d(nn.Module):
    """3D version of pixel shuffle (sub-pixel convolution upsampling).

    Rearranges an input of shape ``(N, C * sd * sh * sw, D, H, W)`` into
    ``(N, C, D * sd, H * sh, W * sw)``. Input channel
    ``c * (sd * sh * sw) + i * (sh * sw) + j * sw + k`` is placed at offset
    ``(i, j, k)`` inside each upsampled cell of output channel ``c``. The
    output-channel index varies slowest, as in ``torch.nn.PixelShuffle``.
    """

    def __init__(self, scale):
        """
        :param scale: upsample factor. Either a positive int applied to depth,
            height and width, or a ``(depth, height, width)`` tuple/list of
            positive per-axis factors.
        :raises ValueError: if ``scale`` is not positive or is not a
            3-element sequence.
        """
        super().__init__()
        self._scale_triple(scale)  # validate early; forward re-reads self.scale
        self.scale = scale

    @staticmethod
    def _scale_triple(scale):
        """Normalize ``scale`` to a validated ``(sd, sh, sw)`` tuple."""
        if isinstance(scale, (tuple, list)):
            scales = tuple(scale)
        else:
            scales = (scale, scale, scale)
        if len(scales) != 3 or any(s < 1 for s in scales):
            raise ValueError(
                'scale must be a positive int or a 3-tuple of positive ints, '
                'got {!r}'.format(scale))
        return scales

    def extra_repr(self):
        return 'scale={!r}'.format(self.scale)

    def forward(self, input):
        """Upsample a 5D ``(N, C, D, H, W)`` tensor by shuffling channels into space.

        :raises ValueError: if ``input`` is not 5D or ``C`` is not divisible
            by ``sd * sh * sw``.
        """
        if input.dim() != 5:
            raise ValueError(
                'PixelShuffle3d expects a 5D input (N, C, D, H, W), '
                'got shape {}'.format(tuple(input.shape)))
        batch_size, channels, in_depth, in_height, in_width = input.size()
        sd, sh, sw = self._scale_triple(self.scale)
        block = sd * sh * sw
        if channels % block != 0:
            # Without this check the floor division below would silently
            # drop channels and `view` would fail with an opaque shape error.
            raise ValueError(
                'input channels ({}) must be divisible by the product of the '
                'scale factors ({})'.format(channels, block))
        n_out = channels // block
        # Split channels into (n_out, sd, sh, sw), then move each scale axis
        # right after its spatial axis: (N, n_out, D, sd, H, sh, W, sw).
        input_view = input.contiguous().view(
            batch_size, n_out, sd, sh, sw, in_depth, in_height, in_width)
        output = input_view.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        return output.view(
            batch_size, n_out, in_depth * sd, in_height * sh, in_width * sw)
