pytorch 如何把图像数据集进行划分成train,test和val


Posted in Python onMay 31, 2021

1、手上目前拥有数据集是一大坨,没有train,test,val的划分

如图所示


pytorch 如何把图像数据集进行划分成train,test和val

2、目录结构:

|---data
     |---dslr
         |---images
         		|---back_pack
         			|---a.jpg
         			|---b.jpg
         			...

3、转换后的格式如图

pytorch 如何把图像数据集进行划分成train,test和val

目录结构为:

|---datanews
     |---dslr
         |---images
         		|---test
         		|---train
         		|---valid
	         		|---back_pack
	         			|---a.jpg
	         			|---b.jpg
	         			...

4、代码如下:

4.1 先创建同样结构的层级结构

4.2 然后讲原始数据按照比例划分

4.3 移入到对应的文件目录里面

import os, random, shutil

def make_dir(source, target):
    '''
    创建和源文件相似的文件路径函数
    :param source: 源文件位置
    :param target: 目标文件位置
    '''
    dir_names = os.listdir(source)
    for names in dir_names:
        for i in ['train', 'valid', 'test']:
            path = target + '/' + i + '/' + names
            if not os.path.exists(path):
                os.makedirs(path)

def divideTrainValiTest(source, target):
    '''
        创建和源文件相似的文件路径
        :param source: 源文件位置
        :param target: 目标文件位置
    '''
    # 得到源文件下的种类
    pic_name = os.listdir(source)
    
    # 对于每一类里的数据进行操作
    for classes in pic_name:
        # 得到这一种类的图片的名字
        pic_classes_name = os.listdir(os.path.join(source, classes))
        random.shuffle(pic_classes_name)
        
        # 按照8:1:1比例划分
        train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
        valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
        test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
        
        # 对于每个图片,移入到对应的文件夹里面
        for train_pic in train_list:
            shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
        for validation_pic in valid_list:
            shutil.copyfile(source + '/' + classes + '/' + validation_pic,
                            target + '/valid/' + classes + '/' + validation_pic)
        for test_pic in test_list:
            shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)

if __name__ == '__main__':
    filepath = r'../data/dslr/images'
    dist = r'../datanews/dslr/images'
    make_dir(filepath, dist)
    divideTrainValiTest(filepath, dist)

补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因

在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法

该方法使用如下:

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0)  # 训练集和验证集使用9:1
 
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
 
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
 
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)  
 
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误

from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
 
traindata = np.load(train_path)   # image_num * W * H
trainlabel = np.load(train_label_path)
 
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
 
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1)  # 训练集和验证集使用9:1

报错原因:

train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3 发送任意文件邮件的实例
Jan 23 Python
Python文件常见操作实例分析【读写、遍历】
Dec 10 Python
python将txt文件读取为字典的示例
Dec 22 Python
Python使用sklearn实现的各种回归算法示例
Jul 04 Python
浅谈PyQt5中异步刷新UI和Python多线程总结
Dec 13 Python
Python 解码Base64 得到码流格式文本实例
Jan 09 Python
关于python pycharm中输出的内容不全的解决办法
Jan 10 Python
TFRecord格式存储数据与队列读取实例
Jan 21 Python
python GUI库图形界面开发之PyQt5浏览器控件QWebEngineView详细使用方法
Feb 26 Python
django实现HttpResponse返回json数据为中文
Mar 27 Python
Python实现捕获异常发生的文件和具体行数
Apr 25 Python
python字符串的多行输出的实例详解
Jun 08 Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
如何在pycharm中快捷安装pip命令(如pygame)
Python 实现绘制子图及子图刻度的变换等问题
python 利用PyAutoGUI快速构建自动化操作脚本
You might like
用Zend Encode编写开发PHP程序
2010/02/21 PHP
PHP 获取远程文件内容的函数代码
2010/03/24 PHP
php中chdir()函数用法实例
2014/11/13 PHP
关于 Laravel Redis 多个进程同时取队列问题详解
2017/12/25 PHP
PHP7 echo和print语句实例用法
2019/02/15 PHP
纯js实现的论坛常用的运行代码的效果
2008/07/15 Javascript
[原创]推荐10款最热门jQuery UI框架
2014/08/19 Javascript
js点击任意区域弹出层消失实现代码
2016/12/27 Javascript
深入理解Javascript中的valueOf与toString
2017/01/04 Javascript
Vue.js基础指令实例讲解(各种数据绑定、表单渲染大总结)
2017/07/03 Javascript
AngularJS实现tab选项卡的方法详解
2017/07/05 Javascript
详解JSONObject和JSONArray区别及基本用法
2017/10/25 Javascript
浅析Vue.js中v-bind v-model的使用和区别
2018/12/04 Javascript
Vue项目实现换肤功能的一种方案分析
2019/08/28 Javascript
Vuex modules模式下mapState/mapMutations的操作实例
2019/10/17 Javascript
解决vue项目本地启动时无法携带cookie的问题
2021/02/06 Vue.js
python bmp转换为jpg 并删除原图的方法
2018/10/25 Python
python 内置模块详解
2019/01/01 Python
Python实现FM算法解析
2019/06/18 Python
在tensorflow中实现去除不足一个batch的数据
2020/01/20 Python
Python pyautogui模块实现鼠标键盘自动化方法详解
2020/02/17 Python
浅谈python元素如何去重,去重后如何保持原来元素的顺序不变
2020/02/28 Python
让Django的BooleanField支持字符串形式的输入方式
2020/05/20 Python
python3.7调试的实例方法
2020/07/21 Python
python 19个值得学习的编程技巧
2020/08/15 Python
CSS3实现闪烁动画效果的方法
2015/02/09 HTML / CSS
HTML5触摸事件(touchstart、touchmove和touchend)的实现
2020/05/08 HTML / CSS
Photobook澳大利亚:制作相片书,婚礼卡,旅行相簿
2017/01/12 全球购物
采用怎样的方法保证数据的完整性
2013/12/02 面试题
财务会计专业毕业生自荐信
2013/10/19 职场文书
作风年建设汇报材料
2014/08/14 职场文书
影视广告专业求职信
2014/09/02 职场文书
学期个人工作总结
2015/02/13 职场文书
2015年暑期见闻
2015/07/14 职场文书
2016年十一促销广告语
2016/01/28 职场文书
Python实现打乒乓小游戏
2021/09/25 Python