Pytorch自己加载单通道图片用作数据集训练的实例


Posted in Python onJanuary 18, 2020

pytorch 在torchvision包里面有很多的的打包好的数据集,例如minist,Imagenet-12,CIFAR10 和CIFAR100。在torchvision的dataset包里面,用的时候直接调用就行了。具体的调用格式可以去看文档(目前好像只有英文的)。网上也有很多源代码。

不过,当我们想利用自己制作的数据集来训练网络模型时,就要有自己的方法了。pytorch在torchvision.dataset包里面封装过一个函数ImageFolder()。这个函数功能很强大,只要你直接将数据集路径保存为例如“train/1/1.jpg ,rain/1/2.jpg …… ”就可以根据根目录“./train”将数据集装载了。

dataset.ImageFolder(root="datapath", transfroms.ToTensor())

但是后来我发现一个问题,就是这个函数加载出来的图像矩阵都是三通道的,并且没有什么参数调用可以让其变为单通道。如果我们要用到单通道数据集(灰度图)的话,比如自己加载Lenet-5模型的数据集,就只能自己写numpy数组再转为pytorch的Tensor()张量了。

接下来是我做的过程:

首先,还是要用到opencv,用灰度图打开一张图片,省事。

#读取图片 这里是灰度图 
 for item in all_path:
  img = cv2.imread(item[1],0)
  img = cv2.resize(img,(28,28))
  arr = np.asarray(img,dtype="float32")
  data_x[i ,:,:,:] = arr
  i+=1
  data_y.append(int(item[0]))
  
 data_x = data_x / 255
 data_y = np.asarray(data_y)

其次,pytorch有自己的numpy转Tensor函数,直接转就行了。

data_x = torch.from_numpy(data_x)
 data_y = torch.from_numpy(data_y)

下一步利用torch.util和torchvision里面的dataLoader函数,就能直接得到和torchvision.dataset里面封装好的包相同的数据集样本了

dataset = dataf.TensorDataset(data_x,data_y)
 loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)

最后就是自己建网络设计参数训练了,这部分和文档以及github中的差不多,就不赘述了。

下面是整个程序的源代码,我利用的还是上次的车标识别的数据集,一共分四类,用的是2层卷积核两层全连接。

源代码:

# coding=utf-8
import os
import cv2
import numpy as np
import random
 
import torch
import torch.nn as nn
import torch.utils.data as dataf
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
 
#训练参数
cuda = False
train_epoch = 20
train_lr = 0.01
train_momentum = 0.5
batchsize = 5
 
 
#测试训练集路径
test_path = "/home/test/"
train_path = "/home/train/"
 
#路径数据
all_path =[]
 
def load_data(data_path):
 signal = os.listdir(data_path)
 for fsingal in signal: 
  filepath = data_path+fsingal
  filename = os.listdir(filepath)
  for fname in filename:
   ffpath = filepath+"/"+fname
   path = [fsingal,ffpath]
   all_path.append(path)
   
#设立数据集多大
 count = len(all_path)
 data_x = np.empty((count,1,28,28),dtype="float32")
 data_y = []
#打乱顺序
 random.shuffle(all_path)
 i=0;
 
#读取图片 这里是灰度图 最后结果是i*i*i*i
#分别表示:batch大小 , 通道数, 像素矩阵
 for item in all_path:
  img = cv2.imread(item[1],0)
  img = cv2.resize(img,(28,28))
  arr = np.asarray(img,dtype="float32")
  data_x[i ,:,:,:] = arr
  i+=1
  data_y.append(int(item[0]))
  
 data_x = data_x / 255
 data_y = np.asarray(data_y)
#  lener = len(all_path)
 data_x = torch.from_numpy(data_x)
 data_y = torch.from_numpy(data_y)
 dataset = dataf.TensorDataset(data_x,data_y)
 
 loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)
  
 return loader
#  print data_y
 
 
 
train_load = load_data(train_path)
test_load = load_data(test_path)
 
class L5_NET(nn.Module):
 def __init__(self):
  super(L5_NET ,self).__init__();
  #第一层输入1,20个卷积核 每个5*5
  self.conv1 = nn.Conv2d(1 , 20 , kernel_size=5)
  #第二层输入20,30个卷积核 每个5*5
  self.conv2 = nn.Conv2d(20 , 30 , kernel_size=5)
  #drop函数
  self.conv2_drop = nn.Dropout2d()
  #全链接层1,展开30*4*4,连接层50个神经元
  self.fc1 = nn.Linear(30*4*4,50)
  #全链接层1,50-4 ,4为最后的输出分类
  self.fc2 = nn.Linear(50,4)
 
 #前向传播
 def forward(self,x):
  #池化层1 对于第一层卷积池化,池化核2*2
  x = F.relu(F.max_pool2d( self.conv1(x)  ,2 ) )
  #池化层2 对于第二层卷积池化,池化核2*2
  x = F.relu(F.max_pool2d( self.conv2_drop( self.conv2(x) ) , 2 ) )
  #平铺轴30*4*4个神经元
  x = x.view(-1 , 30*4*4)
  #全链接1
  x = F.relu( self.fc1(x) )
  #dropout链接
  x = F.dropout(x , training= self.training)
  #全链接w
  x = self.fc2(x)
  #softmax链接返回结果
  return F.log_softmax(x)
 
model = L5_NET()
if cuda :
 model.cuda()
  
 
optimizer = optim.SGD(model.parameters()  , lr =train_lr , momentum = train_momentum )
 
#预测函数
def train(epoch):
 model.train()
 for batch_idx, (data, target) in enumerate(train_load):
  if cuda:
   data, target = data.cuda(), target.cuda()
  data, target = Variable(data), Variable(target)
  #求导
  optimizer.zero_grad()
  #训练模型,输出结果
  output = model(data)
  #在数据集上预测loss
  loss = F.nll_loss(output, target)
  #反向传播调整参数pytorch直接可以用loss
  loss.backward()
  #SGD刷新进步
  optimizer.step()
  #实时输出
  if batch_idx % 10 == 0:
   print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
    epoch, batch_idx * len(data), len(train_load.dataset),
    100. * batch_idx / len(train_load), loss.data[0]))
#    
   
#测试函数
def test(epoch):
 model.eval()
 test_loss = 0
 correct = 0
 for data, target in test_load:
  
  if cuda:
   data, target = data.cuda(), target.cuda()
   
  data, target = Variable(data, volatile=True), Variable(target)
  #在测试集上预测
  output = model(data)
  #计算在测试集上的loss
  test_loss += F.nll_loss(output, target).data[0]
  #获得预测的结果
  pred = output.data.max(1)[1] # get the index of the max log-probability
  #如果正确,correct+1
  correct += pred.eq(target.data).cpu().sum()
 
 #loss计算
 test_loss = test_loss
 test_loss /= len(test_load)
 #输出结果
 print('\nThe {} epoch result : Average loss: {:.6f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
  epoch,test_loss, correct, len(test_load.dataset),
  100. * correct / len(test_load.dataset)))
 
for epoch in range(1, train_epoch+ 1):
 train(epoch)
 test(epoch)

最后的训练结果和在keras下差不多,不过我训练的时候好像把训练集和测试集弄反了,数目好像测试集比训练集还多,有点尴尬,不过无伤大雅。结果图如下:

Pytorch自己加载单通道图片用作数据集训练的实例

以上这篇Pytorch自己加载单通道图片用作数据集训练的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现全角半角转换的方法
Aug 18 Python
Python中使用select模块实现非阻塞的IO
Feb 03 Python
python opencv之分水岭算法示例
Feb 24 Python
Python生成MD5值的两种方法实例分析
Apr 26 Python
解决Django layui {{}}冲突的问题
Aug 29 Python
Python监控服务器实用工具psutil使用解析
Dec 19 Python
Java Spring项目国际化(i18n)详细方法与实例
Mar 20 Python
Linux安装Python3如何和系统自带的Python2并存
Jul 23 Python
python绘制分布折线图的示例
Sep 24 Python
python 实现全球IP归属地查询工具
Dec 18 Python
基于PyTorch中view的用法说明
Mar 03 Python
python实现批量移动文件
Apr 05 Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
Jan 18 #Python
Python中实现输入超时及如何通过变量获取变量名
Jan 18 #Python
Pytorch 计算误判率,计算准确率,计算召回率的例子
Jan 18 #Python
python:目标检测模型预测准确度计算方式(基于IoU)
Jan 18 #Python
You might like
centos 7.2下搭建LNMP环境教程
2016/11/20 PHP
浅谈php中fopen不能创建中文文件名文件的问题
2017/02/06 PHP
PHP实现微信小程序用户授权的工具类示例
2019/03/05 PHP
准确获得页面、窗口高度及宽度的JS
2006/11/26 Javascript
关于JS判断图片是否加载完成且获取图片宽度的方法
2013/04/09 Javascript
jquery easyui滚动条部分设置介绍
2013/09/12 Javascript
JS禁用浏览器退格键实现思路及代码
2013/10/29 Javascript
javascript与jquery中跳出循环的区别总结
2013/11/04 Javascript
直接在JS里创建JSON数据然后遍历使用
2014/07/25 Javascript
JavaScript sub方法入门实例(把字符串显示为下标)
2014/10/17 Javascript
js超时调用setTimeout和间歇调用setInterval实例分析
2015/01/28 Javascript
jQuery选择器源码解读(二):select方法
2015/03/31 Javascript
举例详解Python中smtplib模块处理电子邮件的使用
2015/06/24 Javascript
JavaScript返回上一页的三种方法及区别介绍
2015/07/04 Javascript
js实现复选框的全选和取消全选效果
2017/01/03 Javascript
详解win7 cmd执行vue不是内部命令的解决方法
2017/07/27 Javascript
jQuery实现页码跳转式动态数据分页
2017/12/31 jQuery
jQuery实现表单动态加减、ajax表单提交功能
2018/06/08 jQuery
Python 反转字符串(reverse)的方法小结
2018/02/20 Python
Python实现针对给定单链表删除指定节点的方法
2018/04/12 Python
Django 使用logging打印日志的实例
2018/04/28 Python
Python多线程编程之多线程加锁操作示例
2018/09/06 Python
python调用接口的4种方式代码实例
2019/11/19 Python
Python数据分析pandas模块用法实例详解
2019/11/20 Python
linux 下python多线程递归复制文件夹及文件夹中的文件
2020/01/02 Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
2020/05/25 Python
Python操作MySQL数据库的示例代码
2020/07/13 Python
python 6种方法实现单例模式
2020/12/15 Python
HTML5之SVG 2D入门10—滤镜的定义及使用
2013/01/30 HTML / CSS
amazeui树节点自动展开折叠面板并选中第一个树节点的实现
2020/08/24 HTML / CSS
《祁黄羊》教学反思
2014/04/22 职场文书
2014年学校总务处工作总结
2014/12/08 职场文书
幼儿园老师个人总结
2015/02/28 职场文书
区域销售经理岗位职责
2015/04/02 职场文书
Oracle 11g数据库使用expdp每周进行数据备份并上传到备份服务器
2022/06/28 Oracle
js 实现Material UI点击涟漪效果示例
2022/09/23 Javascript