pytorch加载自己的图像数据集实例


Posted in Python onJuly 07, 2020

之前学习深度学习算法,都是使用网上现成的数据集,而且都有相应的代码。到了自己开始写论文做实验,用到自己的图像数据集的时候,才发现无从下手 ,相信很多新手都会遇到这样的问题。

参考文章https://3water.com/article/177613.htm

下面代码实现了从文件夹内读取所有图片,进行归一化和标准化操作并将图片转化为tensor。最后读取第一张图片并显示。

# 数据处理
import os
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms

transform = transforms.Compose([
 transforms.ToTensor(), # 将图片转换为Tensor,归一化至[0,1]
 # transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1,1]
])

#定义自己的数据集合
class FlameSet(data.Dataset):
 def __init__(self,root):
  # 所有图片的绝对路径
  imgs=os.listdir(root)
  self.imgs=[os.path.join(root,k) for k in imgs]
  self.transforms=transform

 def __getitem__(self, index):
  img_path = self.imgs[index]
  pil_img = Image.open(img_path)
  if self.transforms:
   data = self.transforms(pil_img)
  else:
   pil_img = np.asarray(pil_img)
   data = torch.from_numpy(pil_img)
  return data

 def __len__(self):
  return len(self.imgs)

if __name__ == '__main__':
 dataSet=FlameSet('./test')
 print(dataSet[0])

显示结果:

pytorch加载自己的图像数据集实例

补充知识:使用Pytorch进行读取本地的MINIST数据集并进行装载

pytorch中的torchvision.datasets中自带MINIST数据集,可直接调用模块进行获取,也可以进行自定义自己的Dataset类进行读取本地数据和初始化数据。

1. 直接使用pytorch自带的MNIST进行下载:

缺点: 下载速度较慢,而且如果中途下载失败一般得是重新进行执行代码进行下载:

# # 训练数据和测试数据的下载
# 训练数据和测试数据的下载
trainDataset = torchvision.datasets.MNIST( # torchvision可以实现数据集的训练集和测试集的下载
  root="./data", # 下载数据,并且存放在data文件夹中
  train=True, # train用于指定在数据集下载完成后需要载入哪部分数据,如果设置为True,则说明载入的是该数据集的训练集部分;如果设置为False,则说明载入的是该数据集的测试集部分。
  transform=transforms.ToTensor(), # 数据的标准化等操作都在transforms中,此处是转换
  download=True # 瞎子啊过程中如果中断,或者下载完成之后再次运行,则会出现报错
)

testDataset = torchvision.datasets.MNIST(
  root="./data",
  train=False,
  transform=transforms.ToTensor(),
  download=True
)

2. 自定义dataset类进行数据的读取以及初始化。

其中自己下载的MINIST数据集的内容如下:

pytorch加载自己的图像数据集实例

自己定义的dataset类需要继承: Dataset

需要实现必要的魔法方法:

__init__魔法方法里面进行读取数据文件

__getitem__魔法方法进行支持下标访问

__len__魔法方法返回自定义数据集的大小,方便后期遍历

示例如下:

class DealDataset(Dataset):
  """
    读取数据、初始化数据
  """
  def __init__(self, folder, data_name, label_name,transform=None):
    (train_set, train_labels) = load_minist_data.load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式
    self.train_set = train_set
    self.train_labels = train_labels
    self.transform = transform

  def __getitem__(self, index):

    img, target = self.train_set[index], int(self.train_labels[index])
    if self.transform is not None:
      img = self.transform(img)
    return img, target

  def __len__(self):
    return len(self.train_set)

其中load_minist_data.load_data也是我们自己写的读取数据文件的函数,即放在了load_minist_data.py中的load_data函数中。具体实现如下:

def load_data(data_folder, data_name, label_name):
 """
    data_folder: 文件目录
    data_name: 数据文件名
    label_name:标签数据文件名
  """
 with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据
  y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

 with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:
  x_train = np.frombuffer(
    imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
 return (x_train, y_train)

编写完自定义的dataset就可以进行实例化该类并装载数据:

# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。
trainDataset = DealDataset('MNIST_data/', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())
testDataset = DealDataset('MNIST_data/', "t10k-images-idx3-ubyte.gz","t10k-labels-idx1-ubyte.gz",transform=transforms.ToTensor())

# 训练数据和测试数据的装载
train_loader = dataloader.DataLoader(
  dataset=trainDataset,
  batch_size=100, # 一个批次可以认为是一个包,每个包中含有100张图片
  shuffle=False,
)

test_loader = dataloader.DataLoader(
  dataset=testDataset,
  batch_size=100,
  shuffle=False,
)

构建简单的神经网络并进行训练和测试:

class NeuralNet(nn.Module):

  def __init__(self, input_num, hidden_num, output_num):
    super(NeuralNet, self).__init__()
    self.fc1 = nn.Linear(input_num, hidden_num)
    self.fc2 = nn.Linear(hidden_num, output_num)
    self.relu = nn.ReLU()

  def forward(self,x):
    x = self.fc1(x)
    x = self.relu(x)
    y = self.fc2(x)
    return y

# 参数初始化
epoches = 5
lr = 0.001
input_num = 784
hidden_num = 500
output_num = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 产生训练模型对象以及定义损失函数和优化函数
model = NeuralNet(input_num, hidden_num, output_num)
model.to(device)
criterion = nn.CrossEntropyLoss() # 使用交叉熵作为损失函数
optimizer = optim.Adam(model.parameters(), lr=lr)

# 开始循环训练
for epoch in range(epoches): # 一个epoch可以认为是一次训练循环
  for i, data in enumerate(train_loader):
    (images, labels) = data
    images = images.reshape(-1, 28*28).to(device)
    labels = labels.to(device)
    output = model(images) # 经过模型对象就产生了输出
    loss = criterion(output, labels.long()) # 传入的参数: 输出值(预测值), 实际值(标签)
    optimizer.zero_grad() # 梯度清零
    loss.backward()
    optimizer.step()

    if (i+1) % 100 == 0: # i表示样本的编号
      print('Epoch [{}/{}], Loss: {:.4f}'
         .format(epoch + 1, epoches, loss.item())) # {}里面是后面需要传入的变量
                              # loss.item
# 开始测试
with torch.no_grad():
  correct = 0
  total = 0
  for images, labels in test_loader:
    images = images.reshape(-1, 28*28).to(device) # 此处的-1一般是指自动匹配的意思, 即不知道有多少行,但是确定了列数为28 * 28
                           # 其实由于此处28 * 28本身就已经等于了原tensor的大小,所以,行数也就确定了,为1
    labels = labels.to(device)
    output = model(images)
    _, predicted = torch.max(output, 1)
    total += labels.size(0) # 此处的size()类似numpy的shape: np.shape(train_images)[0]
    correct += (predicted == labels).sum().item()
  print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

以上这篇pytorch加载自己的图像数据集实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python selenium UI自动化解决验证码的4种方法
Jan 05 Python
python获取文件路径、文件名、后缀名的实例
Apr 23 Python
python将pandas datarame保存为txt文件的实例
Feb 12 Python
django 邮件发送模块smtp使用详解
Jul 22 Python
Django中的用户身份验证示例详解
Aug 07 Python
PyTorch中常用的激活函数的方法示例
Aug 20 Python
pycharm运行scrapy过程图解
Nov 22 Python
浅谈ROC曲线的最佳阈值如何选取
Feb 28 Python
PyQt5中向单元格添加控件的方法示例
Mar 24 Python
python字符串的index和find的区别详解
Jun 20 Python
Python如何进行时间处理
Aug 06 Python
如何向scrapy中的spider传递参数的几种方法
Nov 18 Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
keras实现VGG16方式(预测一张图片)
Jul 07 #Python
通过实例解析Python RPC实现原理及方法
Jul 07 #Python
Keras预训练的ImageNet模型实现分类操作
Jul 07 #Python
Scrapy模拟登录赶集网的实现代码
Jul 07 #Python
You might like
用IE远程创建Mysql数据库的简易程序
2006/10/09 PHP
让PHP开发者事半功倍的十大技巧小结
2010/04/20 PHP
php引用和拷贝的区别知识点总结
2019/09/23 PHP
获取Javscript执行函数名称的方法
2006/12/22 Javascript
[原创]后缀就扩展名为js的文件是什么文件
2007/12/06 Javascript
DB.ASP 用Javascript写ASP很灵活很好用很easy
2011/07/31 Javascript
ASP.NET jQuery 实例17 通过使用jQuery validation插件校验ListBox
2012/02/03 Javascript
拖动table标题实现改变td的大小(css+js代码)
2013/04/16 Javascript
一段非常简单的js判断浏览器的内核
2014/08/17 Javascript
node.js中的fs.mkdir方法使用说明
2014/12/17 Javascript
JS实现兼容性好,自动置顶的淘宝悬浮工具栏效果
2015/09/18 Javascript
JavaScript SweetAlert插件实现超酷消息警告框
2016/01/28 Javascript
js实现各种复制到剪贴板的方法(分享)
2016/10/27 Javascript
关于Bootstrap按钮组件消除黄框的方法
2017/05/19 Javascript
nodejs中方法和模块用法示例
2018/12/24 NodeJs
Vue自定义render统一项目组弹框功能
2020/06/07 Javascript
three.js 如何制作魔方
2020/07/31 Javascript
vue监听浏览器原生返回按钮,进行路由转跳操作
2020/09/09 Javascript
[00:03]DOTA2新版本PA至宝展示
2014/11/19 DOTA
[43:35]EG vs Winstrike 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
[43:41]VP vs RNG 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.21.mp4
2020/07/19 DOTA
python 获取et和excel的版本号
2009/04/09 Python
Python常见字符串操作函数小结【split()、join()、strip()】
2018/02/02 Python
opencv实现图片模糊和锐化操作
2018/11/19 Python
python输出pdf文档的实例
2020/02/13 Python
Python基于smtplib协议实现发送邮件
2020/06/03 Python
Python 利用OpenCV给照片换底色的示例代码
2020/08/03 Python
基于IE10/HTML5 开发
2013/04/22 HTML / CSS
优衣库澳大利亚官网:UNIQLO澳大利亚
2017/01/18 全球购物
Sunglasses Shop瑞典:欧洲领先的太阳镜网上商店
2018/04/22 全球购物
智能电子秤、手表和健康监测仪:Withings(之前为诺基亚健康)
2018/10/30 全球购物
银行自荐信范文
2013/10/07 职场文书
《歌唱二小放牛郎》教学反思
2014/04/19 职场文书
2015-2016年小学教导工作总结
2015/07/21 职场文书
幼儿园教学反思范文
2016/03/02 职场文书
oracle删除超过N天数据脚本的方法
2022/02/28 Oracle