解决pytorch读取自制数据集出现过的问题


Posted in Python onMay 31, 2021

问题1

问题描述:

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

解决方式

数据格式不对, 把image转成tensor,参数transform进行如下设置就可以了:transform=transform.ToTensor()。注意检测一下transform

问题2

问题描述:

TypeError: append() takes exactly one argument (2 given)

出现问题的地方

imgs.append(words[0], int(words[1]))

解决方式

加括号,如下

imgs.append((words[0], int(words[1])))

问题3

问题描述

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

解决方式

数据和模型不在同一设备上,应该要么都在GPU运行,要么都在CPU

问题4

问题描述

RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[1, 3, 512, 512] to have 1 channels, but got 3 channels instead

解决方式

图像竟然是RGB,但我的训练图像是一通道的灰度图,所以得想办法把 mode 转换成灰度图L

补充:神经网络 pytorch 数据集读取(自动读取数据集,手动读取自己的数据)

对于pytorch,我们有现成的包装好的数据集可以使用,也可以自己创建自己的数据集,大致来说有三种方法,这其中用到的两个包是datasets和DataLoader

datasets:用于将数据和标签打包成数据集

DataLoader:用于对数据集的高级处理,比如分组,打乱,处理等,在训练和测试中可以直接使用DataLoader进行处理

第一种 现成的打包数据集

这种比较简答,只需要现成的几行代码和一个路径就可以完成,但是一般都是常用比如cifar-10

解决pytorch读取自制数据集出现过的问题

对于常用数据集,可以使用torchvision.datasets直接进行读取,这是对其常用的处理,该类也是继承于torch.utils.data.Dataset。

#是第一次运行的话会下载数据集 现成的话可以使用root参数指定数据集位置
# 存放的格式如下图
 
# 根据接口读取默认的CIFAR10数据 进行训练和测试
#预处理
transform = transform.Compose([transform.ToTensor(), transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#读取数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
#打包成DataLoader
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=1)
 
#同上
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=1)
classes = (1,2,3,4,5,6,7,8,9,10)  #类别定义
 
#使用
 for epoch in range(3):
        running_loss = 0.0 #清空loss
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data #trainloader返回:id,image,labels
 
            # 将inputs与labels装进Variable中
            inputs, labels = Variable(inputs), Variable(labels)
            
            #使用print代替输出
            print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

解决pytorch读取自制数据集出现过的问题

第二种 自己的图像分类

这也是一个方便的做法,在pytorch中提供了torchvision.datasets.ImageFolder让我们训练自己的图像。

要求:创建train和test文件夹,每个文件夹下按照类别名字存储图像就可以实现dataloader

这里还是拿上个举例子吧,实际上也可以是我们的数据集

解决pytorch读取自制数据集出现过的问题

每个下面的布局是这样的

解决pytorch读取自制数据集出现过的问题

# 预处理
transform = transform.Compose([transform.ToTensor(), transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 
#使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹
img_data = torchvision.datasets.ImageFolder('data/cifar-10/train/', transform=transform)
data_loader = torch.utils.data.DataLoader(img_data, batch_size=4, shuffle=True, num_workers=1)
 
testset = torchvision.datasets.ImageFolder('data/cifar-10/test/', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=1)
 
 for epoch in range(3):
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data #trainloader返回:id,image,labels
            # 将inputs与labels装进Variable中
            inputs, labels = Variable(inputs), Variable(labels)
 
            #使用print代替输出
            print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

第三种 一维向量数据集

这个是比较尴尬的,首先我们

假设将数存储到txt等文件中,先把他读取出来,读取的部分就不仔细说了,读到一个列表里就可以

常用的可以是列表等,举例子

trainlist = []  # 保存特征的列表
 
targetpath = 'a/b/b'
filelist = os.listdir(targetpath) #列出文件夹下所有的目录与文件
filecount = len(filelist)
# 根据根路径 读取所有文件名 循环读取文件内容 添加到list
for i in range(filecount):
     filepath = os.path.join(targetpath, filelist[j])
     with open(filepath, 'r') as f:
         line = f.readline()
         # 例如存储格式为 1,2,3,4,5,6 数字之间以逗号隔开
         templist = list(map(int, line.split(',')))
         trainlist.append(templist)
 
# 数据读取完毕 现在为维度为filecount的列表 我们需要转换格式和类型
# 将数据转换为Tensor
 
# 假如我们的两类数据分别存在list0 和 list1中
split = len(list0) # 用于记录标签的分界
 
#使用numpy.array 和 torch.from_numpy 连续将其转换为tensor  使用torch.cat拼接
train0_numpy = numpy.array(list0)
train1_numpy = numpy.array(list1)
train_tensor = torch.cat([torch.from_numpy(train0_numpy), torch.from_numpytrain1_numpy)], 0)
#现在的尺寸是【样本数,长度】 然而在使用神 经网络处理一维数据要求【样本数,维度,长度】
# 这个维度指的像一个图像实际上是一个二维矩阵 但是有三个RGB通道 实际就为【3,行,列】 那么需要处理三个矩阵
# 我们需要在我们的数据中加上这个维度信息
# 注意类型要一样 可以转换
shaper = train_tensor.shape  #获取维度 【样本数,长度】
aa = torch.ones((shaper[0], 1, shaper[1])) # 生成目标矩阵
for i in range(shaper[0]):  # 将所有样本复制到新矩阵
·    aa[i][0][:] = train_tensor[i][:]
train_tensor = aa  # 完成了数据集的转换 【样本数,维度,长度】
 
# 注 意 如果是读取的图像 我们需要的目标维度是【样本数,维度,size_w,size_h】
# 卷积接受的输入是这样的四维度 最后的两个是图像的尺寸 维度表示是通道数量 
  
# 下面是生成标签 标签注意类别之间的分界 split已经在上文计算出来
# 训练标签的
total = len(list0) + len(list1)
train_label = numpy.zeros(total)
train_label[split+1:total] = 1
train_label_tensor = torch.from_numpy(train_label).int()
# print(train_tensor.size(),train_label_tensor.size())
 
# 搭建dataloader完毕
train_dataset = TensorDataset(train_tensor, train_label_tensor)
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
 
for epoch in range(3):
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data #trainloader返回:id,image,labels
        # 将inputs与labels装进Variable中
        inputs, labels = Variable(inputs), Variable(labels)
 
        #使用print代替输出
        print("epoch:", epoch, "的第", i, "个inputs", inputs.data.size(), "labels", labels.data.size())

第四种 保存路径和标签的方式创建数据集

该方法需要略微的麻烦一些,首先你有一个txt,保存了文件名和对应的标签,大概是这个意思

解决pytorch读取自制数据集出现过的问题

然后我们在程序中,根据给定的根目录找到文件,并将标签对应保存

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
	raise NotImplementedError
def __len__(self):
	raise NotImplementedError
def __add__(self, other):
	return ConcatDataset([self, other])

这是dataset的原本内容,getitem就是获取元素的部分,用于返回对应index的数据和标签。那么大概需要做的是我们将txt的内容读取进来,使用程序处理标签和数据

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 初始化读取txt 可以设定变换
def __init__(self, txt_path, transform = None, target_transform = None):
	fh = open(txt_path, 'r')
	imgs = []
	for line in fh:
		line = line.rstrip()
		words = line.split()
         # 保存列表 其中有图像的数据 和标签
		imgs.append((words[0], int(words[1])))
		self.imgs = imgs 
		self.transform = transform
		self.target_transform = target_transform
def __getitem__(self, index):
	fn, label = self.imgs[index]
	img = Image.open(fn).convert('RGB') 
	if self.transform is not None:
		img = self.transform(img) 
    # 返回图像和标签
    
	return img, label
def __len__(self):
	return len(self.imgs)
 
# 当然也可以创建myImageFloder 其txt格式在下图显示 
import os
import torch
import torch.utils.data as data
from PIL import Image 
def default_loader(path):
    return Image.open(path).convert('RGB')
 
class myImageFloder(data.Dataset):
    def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):
        fh = open(label) #打开label文件
        c=0
        imgs=[]  # 保存图像的列表
        class_names=[]
        for line in  fh.readlines(): #读取每一行数据
            if c==0:
                class_names=[n.strip() for n in line.rstrip().split('	')] 
            else:
                cls = line.split() #分割为列表
                fn = cls.pop(0)  #弹出最上的一个
                if os.path.isfile(os.path.join(root, fn)):  # 组合路径名 读取图像
                    imgs.append((fn, tuple([float(v) for v in cls])))  #添加到列表
            c=c+1
 
        # 设置信息
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
 
    def __getitem__(self, index):  # 获取图像 给定序号
        fn, label = self.imgs[index]  #读取图像的内容和对应的label
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:  # 是否变换
            img = self.transform(img)
        return img, torch.Tensor(label) # 返回图像和label
 
    def __len__(self):
        return len(self.imgs)
    
    def getName(self):
        return self.classes
#

解决pytorch读取自制数据集出现过的问题

# 而后使用的时候就可以正常的使用
trainset = MyDataset(txt_path=pathFile,transform = None, target_transform = None)
# trainset = torch.utils.data.DataLoader(myFloder.myImageFloder(root = "../data/testImages/images", label = "../data/testImages/test_images.txt", transform = mytransform ), batch_size= 2, shuffle= False, num_workers= 2)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)

它的要点是,继承dataset,在初始化中处理txt文本数据,保存对应的数据,并实现对应的功能。

这其中的原理就是如此,但是注意可能有些许略微不恰当的地方,可能就需要到时候现场调试了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
PHP网页抓取之抓取百度贴吧邮箱数据代码分享
Apr 13 Python
Python基于回溯法子集树模板解决0-1背包问题实例
Sep 02 Python
python实现随机梯度下降(SGD)
Mar 24 Python
浅析python的Lambda表达式
Feb 27 Python
Python实现二叉树前序、中序、后序及层次遍历示例代码
May 18 Python
Pycharm新手教程(只需要看这篇就够了)
Jun 18 Python
解决pycharm下os.system执行命令返回有中文乱码的问题
Jul 07 Python
python re.sub()替换正则的匹配内容方法
Jul 22 Python
python pandas 时间日期的处理实现
Jul 30 Python
Python中的 sort 和 sorted的用法与区别
Aug 10 Python
Python3 pywin32模块安装的详细步骤
May 26 Python
Django如何在不停机的情况下创建索引
Aug 02 Python
Python爬虫基础初探selenium
只用40行Python代码就能写出pdf转word小工具
pytorch 如何把图像数据集进行划分成train,test和val
May 31 #Python
Python图片检索之以图搜图
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
python中的plt.cm.Paired用法说明
May 31 #Python
在pycharm中无法import所安装的库解决方案
You might like
thinkphp视图模型查询提示ERR: 1146:Table 'db.pr_order_view' doesn't exist的解决方法
2014/10/30 PHP
php实现购物车功能(上)
2020/07/23 PHP
PHP实现活动人选抽奖功能
2017/04/19 PHP
PHP-FPM和Nginx的通信机制详解
2019/02/01 PHP
Flash+XML滚动新闻代码 无图片 附源码下载
2007/11/22 Javascript
JS 事件绑定函数代码
2010/04/28 Javascript
JQuery UI DatePicker中z-index默认为1的解决办法
2010/09/28 Javascript
jQuery队列控制方法详解queue()/dequeue()/clearQueue()
2010/12/02 Javascript
javascript开发随笔二 动态加载js和文件
2011/11/25 Javascript
仿百度的关键词匹配搜索示例
2013/09/25 Javascript
jQuery 中$(this).index与$.each的使用指南
2014/11/20 Javascript
JS判断字符串包含的方法
2015/05/05 Javascript
JavaScript对HTML DOM使用EventListener进行操作
2015/10/21 Javascript
使用node.js搭建服务器
2017/05/20 Javascript
Bootstrap组件之下拉菜单,多级菜单及按钮布局方法实例
2017/05/25 Javascript
使用JavaScript实现点击循环切换图片效果
2017/09/03 Javascript
详解Node.js一行命令上传本地文件到服务器
2019/04/22 Javascript
json数据格式常见操作示例
2019/06/13 Javascript
在Mac OS上部署Nginx和FastCGI以及Flask框架的教程
2015/05/02 Python
Python学生成绩管理系统简洁版
2020/04/05 Python
使用Python微信库itchat获得好友和群组已撤回的消息
2018/06/24 Python
Python实现曲线拟合操作示例【基于numpy,scipy,matplotlib库】
2018/07/12 Python
python实现桌面壁纸切换功能
2019/01/21 Python
解决python写入带有中文的字符到文件错误的问题
2019/01/31 Python
numpy中生成随机数的几种常用函数(小结)
2020/08/18 Python
html5记忆翻牌游戏实现思路及代码
2013/07/25 HTML / CSS
解决Firefox下不支持outerHTML问题代码分享
2014/06/04 HTML / CSS
中国最大的团购网站:聚划算
2016/09/21 全球购物
以特惠价提供在线奢侈品购物:FRMODA.com
2018/01/25 全球购物
四川internet信息高速公路(C#)笔试题
2012/02/29 面试题
全国文明单位申报材料
2014/05/31 职场文书
党的群众路线整改落实情况汇报
2014/10/28 职场文书
酒店采购员岗位职责
2015/04/03 职场文书
小学班主任工作随笔
2015/08/15 职场文书
2016年读书月活动总结范文
2016/04/06 职场文书
Python 如何利用ffmpeg 处理视频素材
2021/11/27 Python