def load_data(mode='train'):datafile = './work/mnist.json.gz'print('loading mnist dataset from {} ......'.format(datafile))# 加载json数据文件data = json.load(gzip.open(datafile))print('mnist dataset load done')# 读取到的数据区分训练集,验证集,测试集train_set, val_set, eval_set = dataif mode=='train':# 获得训练数据集imgs, labels = train_set[0], train_set[1]elif mode=='valid':# 获得验证数据集imgs, labels = val_set[0], val_set[1]elif mode=='eval':# 获得测试数据集imgs, labels = eval_set[0], eval_set[1]else:raise Exception("mode can only be one of ['train', 'valid', 'eval']")print("训练数据集数量: ", len(imgs))# 校验数据imgs_length = len(imgs)assert len(imgs) == len(labels), \"length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(labels))# 获得数据集长度imgs_length = len(imgs)# 定义数据集每个数据的序号,根据序号读取数据index_list = list(range(imgs_length))# 读入数据时用到的批次大小BATCHSIZE = 100# 定义数据生成器def data_generator():if mode == 'train':# 训练模式下打乱数据random.shuffle(index_list)imgs_list = []labels_list = []for i in index_list:# 将数据处理成希望的类型img = np.array(imgs[i]).astype('float32')label = np.array(labels[i]).astype('float32')imgs_list.append(img)labels_list.append(label)if len(imgs_list) == BATCHSIZE:# 获得一个batchsize的数据,并返回yield np.array(imgs_list), np.array(labels_list)# 清空数据读取列表imgs_list = []labels_list = []# 如果剩余数据的数目小于BATCHSIZE,# 则剩余数据一起构成一个大小为len(imgs_list)的mini-batchif len(imgs_list) > 0:yield np.array(imgs_list), np.array(labels_list)return data_generator
这种封装好的数据处理生成器的结构值得参考:
结构第一部分是加载数据,
通过mode关键字来选择划分数据集的种类
使用assert来机器校验,查明数据有没有缺失
通过len(dataset)即数据集长度来获取索引列表
结构第二部分:
如果处理的是训练集,需要对数据集进行打乱
random.shuffle.(index)
打乱训练集数据的下标
输入和目标两个列表 要用numpy处理成数组,记得要对数据进行格式转化,一般是flaot32,或者float64
if len(list_input)==batch_size 也就是说如果输入数据数量达到批次数量
使用生成器输出 yield np.array(list)也就是说生成器最终还是要转化为ndarray类型方便模型处理
如果剩余数据的长度小于批次长度
就一次性使用生成器,把它输出出去
