pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程之序列操作实例详解
Jul 22 Python
python使用os.listdir和os.walk获得文件的路径的方法
Dec 16 Python
我就是这样学习Python中的列表
Jun 02 Python
Python多进程入门、分布式进程数据共享实例详解
Jun 03 Python
Python格式化字符串f-string概览(小结)
Jun 18 Python
Selenium+Python 自动化操控登录界面实例(有简单验证码图片校验)
Jun 28 Python
Pandas分组与排序的实现
Jul 23 Python
python列表推导式入门学习解析
Dec 02 Python
使用PyTorch训练一个图像分类器实例
Jan 08 Python
通过实例了解Python str()和repr()的区别
Jan 17 Python
Python输出指定字符串的方法
Feb 06 Python
python实现可下载音乐的音乐播放器
Feb 25 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
ThinkPHP模板中判断volist循环的最后一条记录的验证方法
2014/07/01 PHP
php中header设置常见文件类型的content-type
2015/06/23 PHP
CI(CodeIgniter)框架实现图片上传的方法
2017/03/24 PHP
php实现大文件断点续传下载实例代码
2019/10/01 PHP
使用Modello编写JavaScript类
2006/12/22 Javascript
cnblogs中在闪存中屏蔽某人的实现代码
2010/11/14 Javascript
javascript for循环从入门到偏门(效率优化+奇特用法)
2012/08/01 Javascript
Javascript模块化编程(一)AMD规范(规范使用模块)
2013/01/17 Javascript
JavaScript实现动画打开半透明提示层的方法
2015/04/21 Javascript
JS模拟键盘打字效果的方法
2015/08/05 Javascript
jQuery基于ID调用指定iframe页面内的方法
2016/07/06 Javascript
js修改onclick动作的四种方法(推荐)
2016/08/18 Javascript
EsLint入门学习教程
2017/02/17 Javascript
详解vue 模拟后台数据(加载本地json文件)调试
2017/08/25 Javascript
nodejs用gulp管理前端文件方法
2018/06/24 NodeJs
koa socket即时通讯的示例代码
2018/09/07 Javascript
微信上传视频文件提示(推荐)
2018/11/22 Javascript
js防抖和节流的深入讲解
2018/12/06 Javascript
Jquery异步上传文件代码实例
2019/11/13 jQuery
Python中的特殊语法:filter、map、reduce、lambda介绍
2015/04/14 Python
python传递参数方式小结
2015/04/17 Python
python机器学习库常用汇总
2017/11/15 Python
Centos 升级到python3后pip 无法使用的解决方法
2018/06/12 Python
PythonWeb项目Django部署在Ubuntu18.04腾讯云主机上
2019/04/01 Python
python装饰器代替set get方法实例
2019/12/19 Python
python 使用多线程创建一个Buffer缓存器的实现思路
2020/07/02 Python
pandas apply多线程实现代码
2020/08/17 Python
python自动化测试三部曲之request+django实现接口测试
2020/10/07 Python
无畏的旅行:Intrepid Travel
2017/12/20 全球购物
阿玛瑞酒店中文官方网站:Amari.com
2018/02/13 全球购物
Microsoft Advertising美国:微软搜索广告
2019/05/01 全球购物
加拿大在线眼镜零售商:SmartBuyGlasses加拿大
2019/05/25 全球购物
机电工程专业应届生求职信
2013/10/03 职场文书
婚前财产公证书
2014/04/10 职场文书
建筑工程挂靠协议书
2016/03/23 职场文书
高性能跳频抗干扰宽带自组网电台
2022/02/18 无线电