1. def load_data(mode='train'):
    2. datafile = './work/mnist.json.gz'
    3. print('loading mnist dataset from {} ......'.format(datafile))
    4. # 加载json数据文件
    5. data = json.load(gzip.open(datafile))
    6. print('mnist dataset load done')
    7. # 读取到的数据区分训练集,验证集,测试集
    8. train_set, val_set, eval_set = data
    9. if mode=='train':
    10. # 获得训练数据集
    11. imgs, labels = train_set[0], train_set[1]
    12. elif mode=='valid':
    13. # 获得验证数据集
    14. imgs, labels = val_set[0], val_set[1]
    15. elif mode=='eval':
    16. # 获得测试数据集
    17. imgs, labels = eval_set[0], eval_set[1]
    18. else:
    19. raise Exception("mode can only be one of ['train', 'valid', 'eval']")
    20. print("训练数据集数量: ", len(imgs))
    21. # 校验数据
    22. imgs_length = len(imgs)
    23. assert len(imgs) == len(labels), \
    24. "length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(labels))
    25. # 获得数据集长度
    26. imgs_length = len(imgs)
    27. # 定义数据集每个数据的序号,根据序号读取数据
    28. index_list = list(range(imgs_length))
    29. # 读入数据时用到的批次大小
    30. BATCHSIZE = 100
    31. # 定义数据生成器
    32. def data_generator():
    33. if mode == 'train':
    34. # 训练模式下打乱数据
    35. random.shuffle(index_list)
    36. imgs_list = []
    37. labels_list = []
    38. for i in index_list:
    39. # 将数据处理成希望的类型
    40. img = np.array(imgs[i]).astype('float32')
    41. label = np.array(labels[i]).astype('float32')
    42. imgs_list.append(img)
    43. labels_list.append(label)
    44. if len(imgs_list) == BATCHSIZE:
    45. # 获得一个batchsize的数据,并返回
    46. yield np.array(imgs_list), np.array(labels_list)
    47. # 清空数据读取列表
    48. imgs_list = []
    49. labels_list = []
    50. # 如果剩余数据的数目小于BATCHSIZE,
    51. # 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch
    52. if len(imgs_list) > 0:
    53. yield np.array(imgs_list), np.array(labels_list)
    54. return data_generator

    这种封装好的数据处理生成器的结构值得参考:

    结构第一部分是加载数据,

    通过mode关键字来选择划分数据集的种类

    使用assert来机器校验,查明数据有没有缺失

    通过len(dataset)即数据集长度来获取索引列表

    结构第二部分:

    如果处理的是训练集,需要对数据集进行打乱

    1. random.shuffle.(index)

    打乱训练集数据的下标

    输入和目标两个列表 要用numpy处理成数组,记得要对数据进行格式转化,一般是flaot32,或者float64

    if len(list_input)==batch_size 也就是说如果输入数据数量达到批次数量

    1. 使用生成器输出 yield np.array(list)也就是说生成器最终还是要转化为ndarray类型方便模型处理

    如果剩余数据的长度小于批次长度

    就一次性使用生成器,把它输出出去