# Source: https://github.com/JaMesLiMers/Frame_Video_Prediction_Pytorch/blob/master/dataloader/MovingMNIST/MovingMNIST.py
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import torch
import codecs
from torchvision import transforms
class MovingMNIST(data.Dataset):
    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

    Every sample is a sequence of frames that is cut into an input part (the
    first ``n_input_frames`` frames) and a target part (all remaining frames).

    Args:
        root (string): Root directory of dataset where
            ``processed/moving_mnist_train.pt`` and
            ``processed/moving_mnist_test.pt`` exist.
        train (bool, optional): If True, creates dataset from the training
            file, otherwise from the test file.
        split (int, optional): Train/test split size. Number defines how many
            samples belong to test set. It is only consulted while the
            processed files are being created by ``download``; files that
            already exist on disk are loaded as they are.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version. It is applied to
            every input frame and the results are combined with
            ``torch.stack``, so it has to return tensors.
        target_transform (callable, optional): Same as ``transform`` but
            applied to the frames of the target part.
    """

    urls = [
        'https://github.com/JaMesLiMers/MovingMNIST/raw/master/mnist_test_seq.npy.gz'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'moving_mnist_train.pt'
    test_file = 'moving_mnist_test.pt'
    # Number of leading frames of a sample returned as the input sequence;
    # everything after them is returned as the target.
    n_input_frames = 10

    def __init__(self, root, train=True, split=1000, transform=None, target_transform=None, download=False):
        """Load the processed train or test tensor, optionally downloading it first.

        Raises:
            RuntimeError: If the processed files are not found under ``root``.
        """
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data = torch.load(self._processed_path(self.training_file))
        else:
            self.test_data = torch.load(self._processed_path(self.test_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (seq, target) where the sampled sequence is split into a
            seq and a target part. Each part is a list of PIL images, or a
            stacked tensor when the matching transform is set.
        """
        frames = self._frames()
        n_input = self.n_input_frames
        seq, target = frames[index, :n_input], frames[index, n_input:]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        # Iterating the array once avoids a hard-coded frame count and a
        # repeated tensor -> numpy conversion for every frame.
        seq = [Image.fromarray(frame, mode='L') for frame in seq.numpy()]
        target = [Image.fromarray(frame, mode='L') for frame in target.numpy()]

        if self.transform is not None:
            seq = torch.stack([self.transform(frame) for frame in seq])
        if self.target_transform is not None:
            target = torch.stack([self.target_transform(frame) for frame in target])

        return seq, target

    def __len__(self):
        """Return the number of sequences in the loaded split."""
        return len(self._frames())

    def _frames(self):
        """Return the tensor of the split selected by ``self.train``."""
        return self.train_data if self.train else self.test_data

    def _processed_path(self, filename):
        """Return the path of ``filename`` inside the processed folder."""
        return os.path.join(self.root, self.processed_folder, filename)

    def _check_exists(self):
        """Return True when both processed files are present on disk."""
        return all(
            os.path.exists(self._processed_path(name))
            for name in (self.training_file, self.test_file)
        )

    @staticmethod
    def _makedirs(path):
        """Create ``path`` (with parents), ignoring an already existing directory."""
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    def download(self):
        """Download the Moving MNIST data if it doesn't exist in processed_folder already.

        Raises:
            ValueError: If ``self.split`` is negative.
        """
        if self._check_exists():
            return

        if self.split < 0:
            raise ValueError('split must be non-negative, got {}'.format(self.split))

        import gzip
        import shutil
        from contextlib import closing
        try:
            from urllib.request import urlopen
        except ImportError:  # Python 2 fallback
            from six.moves.urllib.request import urlopen

        raw_dir = os.path.join(self.root, self.raw_folder)
        processed_dir = os.path.join(self.root, self.processed_folder)
        # Create each folder on its own: a pre-existing raw folder must not
        # prevent the processed folder from being created.
        for folder in (raw_dir, processed_dir):
            self._makedirs(folder)

        # download files
        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # Stream to disk instead of holding the whole payload in memory,
            # and make sure the connection is closed afterwards.
            with closing(urlopen(url)) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)

            if not filename.endswith('.gz'):
                continue
            # Strip only the trailing extension; a '.gz' elsewhere in the
            # path (e.g. in the root directory name) must stay untouched.
            unpacked_path = file_path[:-len('.gz')]
            with gzip.open(file_path, 'rb') as zip_f, open(unpacked_path, 'wb') as out_f:
                shutil.copyfileobj(zip_f, out_f)
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')
        # After the swap, axis 0 is the one that is split into train/test and
        # indexed per sample in __getitem__.
        sequences = np.load(os.path.join(raw_dir, 'mnist_test_seq.npy')).swapaxes(0, 1)
        # Explicit boundary so that split=0 yields an empty test set rather
        # than an empty training set.
        n_train = len(sequences) - min(self.split, len(sequences))
        # Contiguous copies keep each saved tensor independent of the full
        # raw buffer.
        training_set = torch.from_numpy(np.ascontiguousarray(sequences[:n_train]))
        test_set = torch.from_numpy(np.ascontiguousarray(sequences[n_train:]))

        with open(self._processed_path(self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(self._processed_path(self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        """Return a multi-line summary of the dataset and its transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        # Use truthiness, like the rest of the class, so that e.g. train=1 is
        # reported as the split that was actually loaded.
        tmp = 'train' if self.train else 'test'
        fmt_str += ' Train/test: {}\n'.format(tmp)
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str