参考来源:
CSDN:【pytorch】使用 torch.utils.data.random_split() 划分数据集
pytorch 文档:https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split
1. torch.utils.data.random_split()
torch.utils.data.random_split(dataset,lengths,generator=<torch._C.Generator object>)
功能:
随机将一个数据集分割成给定长度的不重叠的新数据集。可选择固定生成器以获得可复现的结果(效果同设置随机种子)。
参数:
**dataset**(Dataset) – 要划分的数据集。**lengths**(sequence) – 要划分的长度。**generator**(Generator) – 用于随机排列的生成器。
示例:
import torchfrom torch.utils.data import random_splitdataset = range(10)train_dataset, test_dataset = random_split(dataset=dataset,lengths=[7, 3],generator=torch.Generator().manual_seed(0))print(list(train_dataset))print(list(test_dataset))"""output:[4, 1, 7, 5, 3, 9, 0][8, 6, 2]"""
**torch.Generator().manual_seed(0)** 和 **torch.manual_seed(0)** 的效果相同,我们验证一下。
import torchfrom torch.utils.data import random_splitdataset = range(10)torch.manual_seed(0)train_dataset, test_dataset = random_split(dataset=dataset,lengths=[7, 3])print(list(train_dataset))print(list(test_dataset))"""output:[4, 1, 7, 5, 3, 9, 0][8, 6, 2]"""
2. sklearn.model_selection 中 train_test_split() 函数
博客园:sklearn 的 train_test_split() 各函数参数含义解释(非常全)
CSDN:sklearn.model_selection 中 train_test_split() 函数
在机器学习中,我们通常将原始数据按照比例分割为“测试集”和“训练集”,从 sklearn.model_selection 中调用 train_test_split() 函数。train_test_split() 是 sklearn.model_selection 中的分离器函数,用于将数组或矩阵划分为训练集和测试集,函数样式为:
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(train_data,train_target,test_size=0.4,random_state=0,stratify=y_train,shuffle)
参数解释:
**train_data**:待划分的样本数据集**train_target**:待划分的对应样本数据的样本标签**test_size**:样本占比,如果是整数的话就是样本的数量。1)浮点数,在0 ~ 1之间,表示样本占比(test_size = 0.3,则样本数据中有30%的数据作为测试数据,记入X_test,其余70%数据记入X_train,同时适用于样本标签);2)整数,表示样本数据中有多少数据记入X_test中,其余数据记入X_train。**random_state**:随机数种子,种子不同,每次采的样本不一样;种子相同,采的样本不变(random_state不取,采样数据不同,但random_state等于某个值,采样数据相同,取 0 的时候也相同,这可以自己编程尝试下,不过想改变数值也可以设置random_state = int(time.time()))**stratify**:是为了保持split前类的分布。比如有100个数据,80个属于A类,20个属于B类。如果train_test_split(... test_size=0.25, stratify = y_all), 那么split之后数据如下:training:75个数据,其中60个属于A类,15个属于B类。testing:25个数据,其中20个属于A类,5个属于B类。- 用了
stratify参数,training集和testing集的类的比例是A:B= 4:1,等同于split前的比例(80:20)。通常在这种类分布不平衡的情况下会用到stratify。 - 将
stratify=X就是按照X中的比例分配 - 将
stratify=y就是按照y中的比例分配
**shuffle**:洗牌模式,1)shuffle = False,不打乱样本数据顺序;2)shuffle = True,打乱样本数据顺序。
整体总结起来各个参数的设置及其类型如下:
主要参数说明:
- *arrays:可以是列表、numpy数组、scipy稀疏矩阵或pandas的数据框
- test_size:可以为浮点、整数或 None,默认为 None。
- ①若为浮点时,表示测试集占总样本的百分比。
- ②若为整数时,表示测试样本样本数。
- ③若为 None 时,
test size自动设置成 0.25。
**train_size**:可以为浮点、整数或 None,默认为 None。- ①若为浮点时,表示训练集占总样本的百分比。
- ②若为整数时,表示训练样本的样本数。
- ③若为 None 时,
train_size自动被设置成 0.75。
**random_state**:可以为整数、RandomState实例或None,默认为 None。- ①若为 None 时,每次生成的数据都是随机,可能不一样。
- ②若为整数时,每次生成的数据都相同。
**stratify**:可以为类似数组或 None。- ①若为 None 时,划分出来的测试集或训练集中,其类标签的比例也是随机的。
- ②若不为 None 时,划分出来的测试集或训练集中,其类标签的比例同输入的数组中类标签的比例相同,可以用于处理不均衡的数据集。
