Pytorch 实现数据集自定义读取


Posted in Python onJanuary 18, 2020

以读取VOC2012语义分割数据集为例,具体见代码注释:

VocDataset.py

from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

#颜色标签空间转到序号标签空间,就他妈这里浪费巨量的时间,这里还他妈的有问题
def voc_label_indices(colormap, colormap2label):
  """Assign label indices for Pascal VOC2012 Dataset."""
  idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])
  #out = np.empty(idx.shape, dtype = np.int64) 
  out = colormap2label[idx]
  out=out.astype(np.int64)#数据类型转换
  end = time.time()
  return out

class MyDataset(data.Dataset):#创建自定义的数据读取类
  def __init__(self, root, is_train, crop_size=(320,480)):
    self.rgb_mean =(0.485, 0.456, 0.406)
    self.rgb_std = (0.229, 0.224, 0.225)
    self.root=root
    self.crop_size=crop_size
    images = []#创建空列表存文件名称
    txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
      self.images = f.read().split()
    #数据名称整理
    self.files = []
    for name in self.images:
      img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
      label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)
      self.files.append({
        "img": img_file,
        "label": label_file,
        "name": name
      })
    self.colormap2label = np.zeros(256**3)
    #整个循环的意思就是将颜色标签映射为单通道的数组索引
    for i, cm in enumerate(VOC_COLORMAP):
      self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
  #按照索引读取每个元素的具体内容
  def __getitem__(self, index):
    
    datafiles = self.files[index]
    name = datafiles["name"]
    image = Image.open(datafiles["img"])
    label = Image.open(datafiles["label"]).convert('RGB')#打开的是PNG格式的图片要转到rgb的格式下,不然结果会比较要命
    #以图像中心为中心截取固定大小图像,小于固定大小的图像则自动填0
    imgCenterCrop = transforms.Compose([
       transforms.CenterCrop(self.crop_size),
       transforms.ToTensor(),
       transforms.Normalize(self.rgb_mean, self.rgb_std),#图像数据正则化
     ])
    labelCenterCrop = transforms.CenterCrop(self.crop_size)
    cropImage=imgCenterCrop(image)
    croplabel=labelCenterCrop(label)
    croplabel=torch.from_numpy(np.array(croplabel)).long()#把标签数据类型转为torch
    
    #将颜色标签图转为序号标签图
    mylabel=voc_label_indices(croplabel, self.colormap2label)
    
    return cropImage,mylabel
  #返回图像数据长度
  def __len__(self):
    return len(self.files)

Train.py

import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from VocDataset import MyDataset

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]]

root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data, 4)

#从数据集中拿出一个批次的数据
for i, data in enumerate(trainloader):
  getimgs, labels= data
  img = transforms.ToPILImage()(getimgs[0])

  labels = labels.numpy()#tensor转numpy
  labels=labels[0]#获得批次标签集中的一张标签图像
  labels = labels.transpose((1,0))#数组维度切换,将第1维换到第0维,第0维换到第1维

  ##将单通道索引标签图片映射回颜色标签图片
  newIm= Image.new('RGB', (480, 320))#创建一张与标签大小相同的图片,用以显示标签所对应的颜色
  for i in range(0, 480):
    for j in range(0, 320):
      sele=labels[i][j]#取得坐标点对应像素的值
      newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2])))

  #显示图像和标签
  plt.figure("image")
  ax1 = plt.subplot(1,2,1)
  ax2 = plt.subplot(1,2,2)
  plt.sca(ax1)
  plt.imshow(img)
  plt.sca(ax2)
  plt.imshow(newIm)
  plt.show()

以上这篇Pytorch 实现数据集自定义读取就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现嵌套列表去重方法示例
Dec 28 Python
python实现数据导出到excel的示例--普通格式
May 03 Python
pygame游戏之旅 计算游戏中躲过的障碍数量
Nov 20 Python
Python-while 计算100以内奇数和的方法
Jun 11 Python
Django rest framework jwt的使用方法详解
Aug 08 Python
py-charm延长试用期限实例
Dec 22 Python
python异常处理之try finally不报错的原因
May 18 Python
Python模拟伯努利试验和二项分布代码实例
May 27 Python
Django中F函数的使用示例代码详解
Jul 06 Python
pandas to_excel 添加颜色操作
Jul 14 Python
深入浅析pycharm中 Make available to all projects的含义
Sep 15 Python
Python字符串的转义字符
Apr 07 Python
使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
Jan 18 #Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
You might like
调频问题解答
2021/03/01 无线电
php对mongodb的扩展(初出茅庐)
2012/11/11 PHP
ThinkPHP CURD方法之data方法详解
2014/06/18 PHP
ThinkPHP页面跳转success与error方法概述
2014/06/25 PHP
php实现无限级分类(递归方法)
2015/08/06 PHP
joomla数据库操作示例代码
2016/01/06 PHP
用 Javascript 验证表单(form)中的单选(radio)值
2009/09/08 Javascript
window.ActiveXObject使用说明
2010/11/08 Javascript
Jquery绑定事件(bind和live的区别介绍)
2013/08/23 Javascript
jquery cookie的用法总结
2013/11/18 Javascript
JS两种定义方式的区别、内部原理
2013/11/21 Javascript
jquery 选取方法都有哪些
2014/05/18 Javascript
通过设置CSS中的position属性来固定层的位置
2015/12/14 Javascript
Html5 js实现手风琴效果
2020/04/17 Javascript
bootstrap侧边栏圆点导航
2017/01/11 Javascript
js Date()日期函数浏览器兼容问题解决方法
2017/09/12 Javascript
Vue插件从封装到发布的完整步骤记录
2019/02/28 Javascript
详解vue-cli3 中跨域解决方案
2019/04/10 Javascript
AngularJs中$cookies简单用法分析
2019/05/30 Javascript
[51:20]完美世界DOTA2联赛PWL S2 Magma vs PXG 第一场 11.28
2020/12/01 DOTA
python实现二分查找算法
2017/09/21 Python
python XlsxWriter模块创建aexcel表格的实例讲解
2018/05/03 Python
python中virtualenvwrapper安装与使用
2018/05/20 Python
Python使用正则表达式分割字符串的实现方法
2019/07/16 Python
python 修改本地网络配置的方法
2019/08/14 Python
python使用PIL和matplotlib获取图片像素点并合并解析
2019/09/10 Python
Django将默认的SQLite更换为MySQL的实现
2019/11/18 Python
Python pymysql模块安装并操作过程解析
2020/10/13 Python
基于Python的接口自动化unittest测试框架和ddt数据驱动详解
2021/01/27 Python
中东地区最大的奢侈品市场:The Luxury Closet
2019/04/09 全球购物
英语专业应届生求职信范文
2013/11/15 职场文书
质量提升方案
2014/06/16 职场文书
郭明义电影观后感
2015/06/08 职场文书
2016大学生党校学习心得体会
2016/01/06 职场文书
文案策划岗位个人自我评价(范文)
2019/08/08 职场文书
Navicat for MySQL的使用教程详解
2021/05/27 MySQL