参考来源:
CSDN:【pytorch】使用 torch.utils.data.random_split() 划分数据集
pytorch 文档:https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split

1. torch.utils.data.random_split()

  1. torch.utils.data.random_split(
  2. dataset,
  3. lengths,
  4. generator=<torch._C.Generator object>
  5. )

功能:
随机将一个数据集分割成给定长度的不重叠的新数据集。可选择固定生成器以获得可复现的结果(效果同设置随机种子)。
参数:

  • **dataset** (Dataset) – 要划分的数据集。
  • **lengths** (sequence) – 要划分的长度。
  • **generator** (Generator) – 用于随机排列的生成器。

示例:

  1. import torch
  2. from torch.utils.data import random_split
  3. dataset = range(10)
  4. train_dataset, test_dataset = random_split(
  5. dataset=dataset,
  6. lengths=[7, 3],
  7. generator=torch.Generator().manual_seed(0)
  8. )
  9. print(list(train_dataset))
  10. print(list(test_dataset))
  11. """
  12. output:
  13. [4, 1, 7, 5, 3, 9, 0]
  14. [8, 6, 2]
  15. """

**torch.Generator().manual_seed(0)****torch.manual_seed(0)** 的效果相同,我们验证一下。

  1. import torch
  2. from torch.utils.data import random_split
  3. dataset = range(10)
  4. torch.manual_seed(0)
  5. train_dataset, test_dataset = random_split(
  6. dataset=dataset,
  7. lengths=[7, 3]
  8. )
  9. print(list(train_dataset))
  10. print(list(test_dataset))
  11. """
  12. output:
  13. [4, 1, 7, 5, 3, 9, 0]
  14. [8, 6, 2]
  15. """

2. sklearn.model_selection 中 train_test_split() 函数

博客园:sklearn 的 train_test_split() 各函数参数含义解释(非常全)
CSDN:sklearn.model_selection 中 train_test_split() 函数

在机器学习中,我们通常将原始数据按照比例分割为“测试集”和“训练集”,从 sklearn.model_selection 中调用 train_test_split() 函数。
train_test_split()sklearn.model_selection 中的分离器函数,用于将数组或矩阵划分为训练集和测试集,函数样式为:

  1. X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
  2. train_data,
  3. train_target,
  4. test_size=0.4,
  5. random_state=0,
  6. stratify=y_train,
  7. shuffle
  8. )

参数解释:

  • **train_data**:待划分的样本数据集
  • **train_target**:待划分的对应样本数据的样本标签
  • **test_size**:样本占比,如果是整数的话就是样本的数量。1)浮点数,在 0 ~ 1 之间,表示样本占比(test_size = 0.3,则样本数据中有 30% 的数据作为测试数据,记入 X_test,其余 70% 数据记入 X_train,同时适用于样本标签);2)整数,表示样本数据中有多少数据记入 X_test 中,其余数据记入 X_train
  • **random_state**:随机数种子,种子不同,每次采的样本不一样;种子相同,采的样本不变(random_state 不取,采样数据不同,但 random_state 等于某个值,采样数据相同,取 0 的时候也相同,这可以自己编程尝试下,不过想改变数值也可以设置 random_state = int(time.time())
  • **stratify**:是为了保持 split 前类的分布。比如有 100 个数据,80个属于 A类,20 个属于 B 类。如果train_test_split(... test_size=0.25, stratify = y_all), 那么 split 之后数据如下:
    • training75 个数据,其中 60 个属于 A 类,15 个属于 B 类。
    • testing25 个数据,其中 20 个属于 A 类,5 个属于 B 类。
    • 用了 stratify 参数,training 集和 testing 集的类的比例是 A:B= 4:1,等同于 split 前的比例(80:20)。通常在这种类分布不平衡的情况下会用到 stratify
    • stratify=X 就是按照 X 中的比例分配
    • stratify=y 就是按照 y 中的比例分配
  • **shuffle**:洗牌模式,1)shuffle = False,不打乱样本数据顺序;2)shuffle = True,打乱样本数据顺序。

整体总结起来各个参数的设置及其类型如下:
主要参数说明:

  • *arrays:可以是列表、numpy数组、scipy稀疏矩阵或pandas的数据框
  • test_size:可以为浮点、整数或 None,默认为 None。
    • ①若为浮点时,表示测试集占总样本的百分比。
    • ②若为整数时,表示测试样本样本数。
    • ③若为 None 时,test size 自动设置成 0.25。
  • **train_size**:可以为浮点、整数或 None,默认为 None。
    • ①若为浮点时,表示训练集占总样本的百分比。
    • ②若为整数时,表示训练样本的样本数。
    • ③若为 None 时,train_size 自动被设置成 0.75。
  • **random_state**:可以为整数、RandomState 实例或None,默认为 None。
    • ①若为 None 时,每次生成的数据都是随机,可能不一样。
    • ②若为整数时,每次生成的数据都相同。
  • **stratify**:可以为类似数组或 None。
    • ①若为 None 时,划分出来的测试集或训练集中,其类标签的比例也是随机的。
    • ②若不为 None 时,划分出来的测试集或训练集中,其类标签的比例同输入的数组中类标签的比例相同,可以用于处理不均衡的数据集。