参考来源:
CSDN:【pytorch】使用 torch.utils.data.random_split() 划分数据集
pytorch 文档:https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split
1. torch.utils.data.random_split()
torch.utils.data.random_split(
dataset,
lengths,
generator=<torch._C.Generator object>
)
功能:
随机将一个数据集分割成给定长度的不重叠的新数据集。可选择固定生成器以获得可复现的结果(效果同设置随机种子)。
参数:
**dataset**
(Dataset) – 要划分的数据集。**lengths**
(sequence) – 要划分的长度。**generator**
(Generator) – 用于随机排列的生成器。
示例:
import torch
from torch.utils.data import random_split
dataset = range(10)
train_dataset, test_dataset = random_split(
dataset=dataset,
lengths=[7, 3],
generator=torch.Generator().manual_seed(0)
)
print(list(train_dataset))
print(list(test_dataset))
"""
output:
[4, 1, 7, 5, 3, 9, 0]
[8, 6, 2]
"""
**torch.Generator().manual_seed(0)**
和 **torch.manual_seed(0)**
的效果相同,我们验证一下。
import torch
from torch.utils.data import random_split
dataset = range(10)
torch.manual_seed(0)
train_dataset, test_dataset = random_split(
dataset=dataset,
lengths=[7, 3]
)
print(list(train_dataset))
print(list(test_dataset))
"""
output:
[4, 1, 7, 5, 3, 9, 0]
[8, 6, 2]
"""
2. sklearn.model_selection 中 train_test_split() 函数
博客园:sklearn 的 train_test_split() 各函数参数含义解释(非常全)
CSDN:sklearn.model_selection 中 train_test_split() 函数
在机器学习中,我们通常将原始数据按照比例分割为“测试集”和“训练集”,从 sklearn.model_selection
中调用 train_test_split()
函数。train_test_split()
是 sklearn.model_selection
中的分离器函数,用于将数组或矩阵划分为训练集和测试集,函数样式为:
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
train_data,
train_target,
test_size=0.4,
random_state=0,
stratify=y_train,
shuffle
)
参数解释:
**train_data**
:待划分的样本数据集**train_target**
:待划分的对应样本数据的样本标签**test_size**
:样本占比,如果是整数的话就是样本的数量。1)浮点数,在0 ~ 1
之间,表示样本占比(test_size = 0.3
,则样本数据中有30%
的数据作为测试数据,记入X_test
,其余70%
数据记入X_train
,同时适用于样本标签);2)整数,表示样本数据中有多少数据记入X_test
中,其余数据记入X_train
。**random_state**
:随机数种子,种子不同,每次采的样本不一样;种子相同,采的样本不变(random_state
不取,采样数据不同,但random_state
等于某个值,采样数据相同,取 0 的时候也相同,这可以自己编程尝试下,不过想改变数值也可以设置random_state = int(time.time())
)**stratify**
:是为了保持split
前类的分布。比如有100
个数据,80
个属于A
类,20
个属于B
类。如果train_test_split(... test_size=0.25, stratify = y_all)
, 那么split
之后数据如下:training
:75
个数据,其中60
个属于A
类,15
个属于B
类。testing
:25
个数据,其中20
个属于A
类,5
个属于B
类。- 用了
stratify
参数,training
集和testing
集的类的比例是A:B= 4:1
,等同于split
前的比例(80:20
)。通常在这种类分布不平衡的情况下会用到stratify
。 - 将
stratify=X
就是按照X
中的比例分配 - 将
stratify=y
就是按照y
中的比例分配
**shuffle**
:洗牌模式,1)shuffle = False
,不打乱样本数据顺序;2)shuffle = True
,打乱样本数据顺序。
整体总结起来各个参数的设置及其类型如下:
主要参数说明:
- *arrays:可以是列表、numpy数组、scipy稀疏矩阵或pandas的数据框
- test_size:可以为浮点、整数或 None,默认为 None。
- ①若为浮点时,表示测试集占总样本的百分比。
- ②若为整数时,表示测试样本样本数。
- ③若为 None 时,
test size
自动设置成 0.25。
**train_size**
:可以为浮点、整数或 None,默认为 None。- ①若为浮点时,表示训练集占总样本的百分比。
- ②若为整数时,表示训练样本的样本数。
- ③若为 None 时,
train_size
自动被设置成 0.75。
**random_state**
:可以为整数、RandomState
实例或None,默认为 None。- ①若为 None 时,每次生成的数据都是随机,可能不一样。
- ②若为整数时,每次生成的数据都相同。
**stratify**
:可以为类似数组或 None。- ①若为 None 时,划分出来的测试集或训练集中,其类标签的比例也是随机的。
- ②若不为 None 时,划分出来的测试集或训练集中,其类标签的比例同输入的数组中类标签的比例相同,可以用于处理不均衡的数据集。