https://github.com/JaMesLiMers/Frame_Video_Prediction_Pytorch/blob/master/dataloader/MovingMNIST/MovingMNIST.py

    1. from __future__ import print_function
    2. import torch.utils.data as data
    3. from PIL import Image
    4. import os
    5. import os.path
    6. import errno
    7. import numpy as np
    8. import torch
    9. import codecs
    10. from torchvision import transforms
    class MovingMNIST(data.Dataset):
        """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

        Each item is one sequence split into an input part (its first
        ``input_frames`` frames) and a target part (all remaining frames).

        Args:
            root (string): Root directory of dataset where
                ``processed/moving_mnist_train.pt`` and
                ``processed/moving_mnist_test.pt`` exist.
            train (bool, optional): If truthy, creates dataset from
                ``moving_mnist_train.pt``, otherwise from ``moving_mnist_test.pt``.
            split (int, optional): Train/test split size. Number defines how many
                sequences (taken from the end of the raw file) belong to the test
                set. Only used when :meth:`download` creates the processed files;
                already-processed files are not re-split.
            download (bool, optional): If true, downloads the dataset from the internet and
                puts it in root directory. If dataset is already downloaded, it is not
                downloaded again.
            transform (callable, optional): A function/transform applied to every
                input frame (a PIL image). Its outputs are combined with
                ``torch.stack``, so it must return tensors, e.g. ``transforms.ToTensor()``.
            target_transform (callable, optional): Same as ``transform``, but applied
                to every target frame.
        """
        urls = [
            'https://github.com/JaMesLiMers/MovingMNIST/raw/master/mnist_test_seq.npy.gz'
        ]
        raw_folder = 'raw'
        processed_folder = 'processed'
        training_file = 'moving_mnist_train.pt'
        test_file = 'moving_mnist_test.pt'
        # Number of leading frames of a sequence returned as the input part;
        # the remaining frames are returned as the target part.
        input_frames = 10

        def __init__(self, root, train=True, split=1000, transform=None, target_transform=None, download=False):
            self.root = os.path.expanduser(root)
            self.transform = transform
            self.target_transform = target_transform
            self.split = split
            self.train = train  # training set or test set

            if download:
                self.download()

            if not self._check_exists():
                raise RuntimeError('Dataset not found.' +
                                   ' You can use download=True to download it')

            # Only the requested partition is loaded into memory.
            if self.train:
                self.train_data = torch.load(
                    os.path.join(self.root, self.processed_folder, self.training_file))
            else:
                self.test_data = torch.load(
                    os.path.join(self.root, self.processed_folder, self.test_file))

        def __getitem__(self, index):
            """
            Args:
                index (int): Index

            Returns:
                tuple: (seq, target) where the sampled sequence is split into its
                first ``input_frames`` frames (seq) and the remaining frames
                (target). Each part is a list of PIL images, or a tensor stacked
                from the transformed frames when the matching transform is set.
            """
            sequences = self.train_data if self.train else self.test_data
            seq = sequences[index, :self.input_frames]
            target = sequences[index, self.input_frames:]

            # doing this so that it is consistent with all other datasets
            # to return a PIL Image
            seq = [Image.fromarray(frame, mode='L') for frame in seq.numpy()]
            target = [Image.fromarray(frame, mode='L') for frame in target.numpy()]

            if self.transform is not None:
                seq = torch.stack([self.transform(frame) for frame in seq])

            if self.target_transform is not None:
                target = torch.stack([self.target_transform(frame) for frame in target])

            return seq, target

        def __len__(self):
            if self.train:
                return len(self.train_data)
            return len(self.test_data)

        def _check_exists(self):
            """Return True when both processed files are present under ``root``."""
            return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
                os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

        def download(self):
            """Download the Moving MNIST data if it doesn't exist in processed_folder already.

            Fetches every URL in ``urls`` into ``raw_folder`` and gunzips ``.gz``
            files there (removing the archive). It then splits the sequences of
            ``raw_folder/mnist_test_seq.npy`` into the training and test files in
            ``processed_folder``; the last ``self.split`` sequences go to the test set.

            Raises:
                ValueError: if ``self.split`` is not between 0 and the number of
                    sequences in the raw file.
            """
            import gzip
            import shutil
            from contextlib import closing
            try:
                from urllib.request import urlopen
            except ImportError:  # Python 2
                from six.moves.urllib.request import urlopen

            if self._check_exists():
                return

            # Create each folder independently: a pre-existing raw folder must
            # not stop the processed folder from being created.
            for folder in (self.raw_folder, self.processed_folder):
                try:
                    os.makedirs(os.path.join(self.root, folder))
                except OSError as e:
                    if e.errno != errno.EEXIST:
                        raise

            # download files
            for url in self.urls:
                print('Downloading ' + url)
                filename = url.rpartition('/')[2]
                file_path = os.path.join(self.root, self.raw_folder, filename)
                # Stream to disk (no full in-memory copy) and always close the connection.
                with closing(urlopen(url)) as response, open(file_path, 'wb') as f:
                    shutil.copyfileobj(response, f)
                if file_path.endswith('.gz'):
                    # Strip only the trailing suffix, not every '.gz' in the path.
                    with open(file_path[:-len('.gz')], 'wb') as out_f, \
                            gzip.GzipFile(file_path) as zip_f:
                        shutil.copyfileobj(zip_f, out_f)
                    os.unlink(file_path)

            # process and save as torch files
            print('Processing...')
            # Put the sequence axis first (it is what __getitem__ indexes),
            # followed by the frame axis. Load the raw file only once.
            sequences = np.load(
                os.path.join(self.root, self.raw_folder, 'mnist_test_seq.npy')).swapaxes(0, 1)
            n_total = len(sequences)
            if not 0 <= self.split <= n_total:
                raise ValueError('split must be between 0 and {}, got {}'.format(n_total, self.split))
            # Slice by an explicit count so split=0 means "no test sequences"
            # (``[:-0]`` would be empty). ascontiguousarray copies just the slice,
            # so each saved file holds its own data instead of a view onto the
            # whole raw array (torch.save stores a view's full storage).
            n_train = n_total - self.split
            training_set = torch.from_numpy(np.ascontiguousarray(sequences[:n_train]))
            test_set = torch.from_numpy(np.ascontiguousarray(sequences[n_train:]))

            with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
                torch.save(training_set, f)
            with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
                torch.save(test_set, f)

            print('Done!')

        def __repr__(self):
            fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
            fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
            # Same truthiness test as __init__, so the label matches the data loaded.
            tmp = 'train' if self.train else 'test'
            fmt_str += ' Train/test: {}\n'.format(tmp)
            fmt_str += ' Root Location: {}\n'.format(self.root)
            tmp = ' Transforms (if any): '
            fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
            tmp = ' Target Transforms (if any): '
            fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
            return fmt_str