# Source: https://github.com/JaMesLiMers/Frame_Video_Prediction_Pytorch/blob/master/dataloader/MovingMNIST/MovingMNIST.py
from __future__ import print_function

import codecs
import errno
import os
import os.path

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms


class MovingMNIST(data.Dataset):
    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

    Each sample is a frame sequence that is split into an input part (the
    first ``input_length`` frames) and a target part (the remaining frames).

    Args:
        root (string): Root directory of the dataset, where
            ``processed/moving_mnist_train.pt`` and
            ``processed/moving_mnist_test.pt`` exist (or will be created
            when ``download=True``).
        train (bool, optional): If true, loads the training file, otherwise
            the test file.
        split (int, optional): Train/test split size. Number of samples,
            taken from the end of the downloaded data, that belong to the
            test set. Only used while the processed files are created.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version. It is applied to
            every input frame and the results are combined with
            ``torch.stack``, so it has to return a tensor.
        target_transform (callable, optional): Same as ``transform``, but
            applied to every target frame.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in the root directory. If the processed
            files already exist, nothing is downloaded again.
    """

    urls = [
        'https://github.com/JaMesLiMers/MovingMNIST/raw/master/mnist_test_seq.npy.gz'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'moving_mnist_train.pt'
    test_file = 'moving_mnist_test.pt'
    # Number of leading frames of a sample that form the input sequence;
    # all frames after them form the target sequence.
    input_length = 10

    def __init__(self, root, train=True, split=1000, transform=None,
                 target_transform=None, download=False):
        """Load the processed training or test tensor, downloading if asked.

        Raises:
            RuntimeError: If the processed files are missing (after the
                optional download step).
        """
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.'
                               + ' You can use download=True to download it')

        # Only the tensor of the selected subset is loaded and kept.
        if self.train:
            self.train_data = torch.load(
                self._processed_path(self.training_file))
        else:
            self.test_data = torch.load(
                self._processed_path(self.test_file))

    def _processed_path(self, filename):
        """Return the path of ``filename`` inside the processed folder."""
        return os.path.join(self.root, self.processed_folder, filename)

    def _active_data(self):
        """Return the tensor of the subset selected by ``self.train``."""
        return self.train_data if self.train else self.test_data

    def __getitem__(self, index):
        """Return one sample, split into input and target frames.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: ``(seq, target)``. Each element is a list of PIL images
            (mode ``'L'``), or, when the matching transform is set, the
            ``torch.stack`` of the transformed frames.
        """
        sample = self._active_data()
        seq = sample[index, :self.input_length]
        target = sample[index, self.input_length:]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image; the tensor is converted to numpy once and
        # iterated frame by frame instead of being converted per frame.
        seq = [Image.fromarray(frame, mode='L') for frame in seq.numpy()]
        target = [Image.fromarray(frame, mode='L')
                  for frame in target.numpy()]

        if self.transform is not None:
            seq = torch.stack([self.transform(img) for img in seq])

        if self.target_transform is not None:
            target = torch.stack(
                [self.target_transform(img) for img in target])

        return seq, target

    def __len__(self):
        """Return the number of samples in the loaded subset."""
        return len(self._active_data())

    def _check_exists(self):
        """Return True if both processed files (train and test) exist."""
        return (os.path.exists(self._processed_path(self.training_file))
                and os.path.exists(self._processed_path(self.test_file)))

    @staticmethod
    def _makedirs(path):
        """Create ``path`` (with parents); an existing path is not an error."""
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    @staticmethod
    def _split_sequences(sequences, split):
        """Split ``sequences`` so that the last ``split`` items are the test part.

        The cut index is computed explicitly because ``[:-split]`` and
        ``[-split:]`` swap their meaning for ``split == 0``. ``split`` is
        clamped to ``[0, len(sequences)]``.

        Returns:
            tuple: ``(training_part, test_part)``.
        """
        total = len(sequences)
        cut = min(max(total - split, 0), total)
        return sequences[:cut], sequences[cut:]

    def download(self):
        """Download the Moving MNIST data if it doesn't exist in processed_folder already."""
        from contextlib import closing
        import gzip
        import shutil

        from six.moves import urllib

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        processed_dir = os.path.join(self.root, self.processed_folder)

        # Create each directory on its own: with a shared try block an
        # already existing raw folder would skip creating the processed one.
        self._makedirs(raw_dir)
        self._makedirs(processed_dir)

        # download files
        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)

            # Stream the response to disk and always close the connection.
            with closing(urllib.request.urlopen(url)) as response, \
                    open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)

            # Only strip a trailing '.gz'; a plain str.replace would also
            # rewrite a '.gz' that occurs elsewhere in the path.
            if file_path.endswith('.gz'):
                unpacked_path = file_path[:-len('.gz')]
                with gzip.GzipFile(file_path) as zip_f, \
                        open(unpacked_path, 'wb') as out_f:
                    shutil.copyfileobj(zip_f, out_f)
                os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        # Load the raw array once. swapaxes(0, 1) exchanges the first two
        # axes so that indexing the first axis selects one sample
        # (presumably the raw file is frame-major -- confirm against data).
        sequences = np.load(
            os.path.join(raw_dir, 'mnist_test_seq.npy')).swapaxes(0, 1)
        train_part, test_part = self._split_sequences(sequences, self.split)

        # Copy each part into its own contiguous array so that a saved
        # tensor is not a view into the full array.
        training_set = torch.from_numpy(np.ascontiguousarray(train_part))
        test_set = torch.from_numpy(np.ascontiguousarray(test_part))

        with open(self._processed_path(self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(self._processed_path(self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        """Return a multi-line summary: size, subset, root and transforms."""
        # Use truthiness like the rest of the class, so the reported subset
        # always matches the data that was actually loaded.
        subset = 'train' if self.train else 'test'

        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Train/test: {}\n'.format(subset)
        fmt_str += ' Root Location: {}\n'.format(self.root)

        # Continuation lines of a multi-line transform repr are indented to
        # line up behind the label.
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(
            tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp,
            repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str