# 数据读取函数 (data loading functions)
import numpy as np
import torch as t
from torch.utils import data

# Load the pre-split dataset arrays from the .npz archive.
# The archive is bound to its own name (not `data`) so the `torch.utils.data`
# module alias imported above stays usable further down, and the context
# manager closes the underlying file once the arrays have been read.
with np.load('../dataset/dataset_lstm.npz') as npz_file:
    X_train = npz_file['X_train']
    X_test = npz_file['X_test']
    y_train = npz_file['y_train']
    y_test = npz_file['y_test']
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

# Convert the test split to float32 tensors with one call per array instead
# of building and stacking one small tensor per sample. unsqueeze(1) keeps
# the extra per-sample axis that wrapping each label in a list used to add.
X_tensor = t.as_tensor(X_test, dtype=t.float32)
y_tensor = t.as_tensor(y_test, dtype=t.float32).unsqueeze(1)
print(X_tensor.shape, y_tensor.shape)

# dataset
dataset = data.TensorDataset(X_tensor, y_tensor)
# loader
data_loader = data.DataLoader(dataset, batch_size=5, num_workers=8)
# Output: (2130475, 20, 3) (2130475,) (558720, 20, 3) (558720,)
# Read batches through an explicit iterator.
data_iter = iter(data_loader)
# Take at most the first 10 batches. Pairing with range(10) ends the loop
# cleanly when the loader yields fewer than 10 batches (a bare next() would
# raise StopIteration), and since range is exhausted first, no 11th batch
# is pulled from the iterator.
# NOTE(review): the original indentation was lost; both prints are assumed
# to belong to the loop body — confirm.
for i, (X_batch, y_batch) in zip(range(10), data_iter):
    print(X_batch.shape)
    print(y_batch)
# Iterate over the loader directly; the break below stops after the first
# 12 batches (i = 0..11).
# NOTE(review): global_step is not initialised anywhere in the visible code;
# it must already exist before this loop runs — confirm.
# The loop variable is named `batch` rather than `data` so it does not
# overwrite the `torch.utils.data` module alias.
for i, batch in enumerate(data_loader):
    # get batch
    global_step += 1
    inputs, labels = batch
    if i > 10:
        break
# NOTE(review): Sketchy, img_paths and DataLoader are not defined in the
# visible code — presumably imported/defined elsewhere; confirm.
dataset = Sketchy(img_paths)
data_loader = DataLoader(dataset, batch_size=20, shuffle=False, num_workers=6)
data_iter = iter(data_loader)
batch_id = 0
# One complete pass over the loader. A for-loop ends on iterator exhaustion
# by itself, so no blanket exception handler is needed and genuine loader
# errors are no longer silently swallowed. batch_id is advanced per batch so
# the printed index reflects the batch actually being shown.
for data_batch, name_batch in data_iter:
    print(batch_id, data_batch.shape)
    batch_id += 1
