数据读取函数

    https://stackoverflow.com/questions/44429199/how-to-load-a-list-of-numpy-arrays-to-pytorch-dataset-loader

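    import numpy as np
    import torch as t
    from torch.utils import data

    # Name the archive something other than `data`, so the torch.utils.data module imported
    # above is not shadowed (data.TensorDataset / data.DataLoader are used below).
    npz = np.load('../dataset/dataset_lstm.npz')
    X_train = npz['X_train']
    X_test = npz['X_test']
    y_train = npz['y_train']
    y_test = npz['y_test']
    print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
    # Per-sample conversion as in the linked answer. Same values, much faster:
    #     X_tensor = t.from_numpy(X_test).float()
    #     y_tensor = t.from_numpy(y_test).float().unsqueeze(1)
    X_tensor = t.stack([t.Tensor(x) for x in X_test])
    y_tensor = t.stack([t.Tensor([y]) for y in y_test])  # each label wrapped -> shape (N, 1)
    print(X_tensor.shape, y_tensor.shape)
    # dataset
    dataset = data.TensorDataset(X_tensor, y_tensor)
    # loader
    data_loader = data.DataLoader(dataset, batch_size=5, num_workers=8)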

    # output of the first print (X_train, y_train, X_test, y_test shapes):
    # (2130475, 20, 3) (2130475,) (558720, 20, 3) (558720,)

    # 通过 iter 的方式进行读取
    data_iter = iter(data_loader)
    for i in range(10):
        X_batch, y_batch = next(data_iter)
        print(X_batch.shape)
        print(y_batch)
    
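    # 读取全部: iterate the loader with a for loop (this demo breaks out early, after 12 batches)
    global_step = 0
    for i, batch in enumerate(data_loader, 0):
        # get batch (named `batch`, not `data`, to avoid shadowing torch.utils.data)
        global_step += 1
        inputs, labels = batch

        if i > 10:
            break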
    
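    # NOTE: Sketchy and img_paths are not defined in these notes. Sketchy is presumably a
    # custom Dataset whose items unpack to (data, name) — confirm against its definition.
    from torch.utils.data import DataLoader

    dataset = Sketchy(img_paths)
    data_loader = DataLoader(dataset, batch_size=20, shuffle=False, num_workers=6)
    data_iter = iter(data_loader)
    batch_id = 0
    # 完整遍历一次 (one full pass): next() raises StopIteration when the loader is exhausted.
    # Catch only StopIteration so real errors (e.g. raised in a worker) are not silently hidden.
    # Equivalent, simpler form:
    #     for batch_id, (data_batch, name_batch) in enumerate(data_loader): ...
    try:
        while True:
            data_batch, name_batch = next(data_iter)
            print(batch_id, data_batch.shape)
            batch_id += 1
    except StopIteration:
        pass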