def load_data(mode='train'):
datafile = './work/mnist.json.gz'
print('loading mnist dataset from {} ......'.format(datafile))
# 加载json数据文件
data = json.load(gzip.open(datafile))
print('mnist dataset load done')
# 读取到的数据区分训练集,验证集,测试集
train_set, val_set, eval_set = data
if mode=='train':
# 获得训练数据集
imgs, labels = train_set[0], train_set[1]
elif mode=='valid':
# 获得验证数据集
imgs, labels = val_set[0], val_set[1]
elif mode=='eval':
# 获得测试数据集
imgs, labels = eval_set[0], eval_set[1]
else:
raise Exception("mode can only be one of ['train', 'valid', 'eval']")
print("训练数据集数量: ", len(imgs))
# 校验数据
imgs_length = len(imgs)
assert len(imgs) == len(labels), \
"length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(labels))
# 获得数据集长度
imgs_length = len(imgs)
# 定义数据集每个数据的序号,根据序号读取数据
index_list = list(range(imgs_length))
# 读入数据时用到的批次大小
BATCHSIZE = 100
# 定义数据生成器
def data_generator():
if mode == 'train':
# 训练模式下打乱数据
random.shuffle(index_list)
imgs_list = []
labels_list = []
for i in index_list:
# 将数据处理成希望的类型
img = np.array(imgs[i]).astype('float32')
label = np.array(labels[i]).astype('float32')
imgs_list.append(img)
labels_list.append(label)
if len(imgs_list) == BATCHSIZE:
# 获得一个batchsize的数据,并返回
yield np.array(imgs_list), np.array(labels_list)
# 清空数据读取列表
imgs_list = []
labels_list = []
# 如果剩余数据的数目小于BATCHSIZE,
# 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch
if len(imgs_list) > 0:
yield np.array(imgs_list), np.array(labels_list)
return data_generator
这种封装好的数据处理生成器的结构值得参考:
结构第一部分是加载数据,
通过mode关键字来选择划分数据集的种类
使用assert来机器校验,查明数据有没有缺失
通过len(dataset)即数据集长度来获取索引列表
结构第二部分:
如果处理的是训练集,需要对数据集进行打乱
random.shuffle.(index)
打乱训练集数据的下标
输入和目标两个列表 要用numpy处理成数组,记得要对数据进行格式转化,一般是flaot32,或者float64
if len(list_input)==batch_size 也就是说如果输入数据数量达到批次数量
使用生成器输出 yield np.array(list)也就是说生成器最终还是要转化为ndarray类型方便模型处理
如果剩余数据的长度小于批次长度
就一次性使用生成器,把它输出出去