pytorch实现对输入超过三通道的数据进行训练


Posted in Python onJanuary 15, 2020

案例背景:视频识别

假设每次输入是8s的灰度视频,视频帧率为25fps,则视频由200帧图像序列构成.每帧是一副单通道的灰度图像,通过pythonb里面的np.stack(深度拼接)可将200帧拼接成200通道的深度数据.进而送到网络里面去训练.

如果输入图像200通道觉得多,可以对视频进行抽帧,针对具体场景可以随机抽帧或等间隔抽帧.比如这里等间隔抽取40帧.则最后输入视频相当于输入一个40通道的图像数据了.

pytorch对超过三通道数据的加载:

读取视频每一帧,转为array格式,然后依次将每一帧进行深度拼接,最后得到一个40通道的array格式的深度数据,保存到pickle里.

对每个视频都进行上述操作,保存到pickle里.

我这里将火的视频深度数据保存在一个.pkl文件中,一共2504个火的视频,即2504个火的深度数据.

将非火的视频深度数据保存在一个.pkl文件中,一共3985个非火的视频,即3985个非火的深度数据.

数据加载

import torch 
from torch.utils import data
import os
from PIL import Image
import numpy as np
import pickle
 
class Fire_Unfire(data.Dataset):
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')
    
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)#高*宽*通道
      fire = fire.transpose(2,0,1)#通道*高*宽
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,0,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
    
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
 
#转换成pytorch网络输入的格式(批量大小,通道数,高,宽)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)

模型训练

import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
from torch.autograd import Variable as V
import numpy as np
import sys
import time
 
opt = default_config()
def train():
  #模型定义
  model = mobilenet().cuda()
  if opt.pretrain_model:
    model.load_state_dict(torch.load(opt.pretrain_model))
  
  #损失函数
  criterion = torch.nn.CrossEntropyLoss().cuda()
  
  #学习率
  lr = opt.lr
  
  #优化器
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  
  
  pre_loss = 0.0
  #训练
  for epoch in range(opt.max_epoch):
     #训练数据
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,drop_last = True)
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #print(i,datas.size(),labels)
      #梯度清零
      optimizer.zero_grad()
      #输入
      input = V(datas.cuda()).float()
      #目标
      target = V(labels.cuda()).long()
      #输出
      score = model(input).cuda()
      #损失
      loss = criterion(score,target)
      loss_sum += loss
      #反向传播
      loss.backward()
      #梯度更新
      optimizer.step()      
    print('{}{}{}{}{}'.format('epoch:',epoch,',','loss:',loss))
    torch.save(model.state_dict(),'models/mobilenet_%d.pth'%(epoch+370))

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

解决方案:target = target.long()

以上这篇pytorch实现对输入超过三通道的数据进行训练就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的hashlib和base64加密模块使用实例
Sep 02 Python
python如何在列表、字典中筛选数据
Mar 19 Python
TensorFlow实现卷积神经网络
May 24 Python
对python 自定义协议的方法详解
Feb 13 Python
Python异步操作MySQL示例【使用aiomysql】
May 16 Python
Python3 使用pillow库生成随机验证码
Aug 26 Python
Python Numpy,mask图像的生成详解
Feb 19 Python
Python+redis通过限流保护高并发系统
Apr 15 Python
python实现PDF中表格转化为Excel的方法
Jun 16 Python
python多线程semaphore实现线程数控制的示例
Aug 10 Python
python使用matplotlib:subplot绘制多个子图的示例
Sep 24 Python
Django权限控制的使用
Jan 07 Python
Pytorch 定义MyDatasets实现多通道分别输入不同数据方式
Jan 15 #Python
pytorch构建多模型实例
Jan 15 #Python
利用Pytorch实现简单的线性回归算法
Jan 15 #Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
You might like
php array_filter除去数组中的空字符元素
2020/06/21 PHP
php常用正则函数实例小结
2016/12/29 PHP
JavaScript isPrototypeOf和hasOwnProperty使用区别
2010/03/04 Javascript
jQuery UI Datepicker length为空或不是对象错误的解决方法
2010/12/19 Javascript
JS实现图片预加载无需等待
2012/12/21 Javascript
js实时获取系统当前时间实例代码
2013/06/28 Javascript
jQuery简单实现日历的方法
2015/05/04 Javascript
jquery实现定时自动轮播特效
2015/12/10 Javascript
JS函数arguments数组获得实际传参数个数的实现方法
2016/05/28 Javascript
JavaScript实现in-place思想的快速排序方法
2016/08/07 Javascript
Bootstrap 3.x打印预览背景色与文字显示异常的解决
2016/11/06 Javascript
Vue 2.0 服务端渲染入门介绍
2017/03/29 Javascript
为什么我们要做三份 Webpack 配置文件
2017/09/18 Javascript
Vue打包后出现一些map文件的解决方法
2018/02/13 Javascript
JS定时器如何实现提交成功提示功能
2020/06/12 Javascript
前端vue+elementUI如何实现记住密码功能
2020/09/20 Javascript
修改NPM全局模式的默认安装路径的方法
2020/12/15 Javascript
[41:05]Serenity vs Pain 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
python多线程编程中的join函数使用心得
2014/09/02 Python
centos6.7安装python2.7.11的具体方法
2017/01/16 Python
python3+PyQt5图形项的自定义和交互 python3实现page Designer应用程序
2020/07/20 Python
python单例模式实例解析
2018/08/28 Python
dpn网络的pytorch实现方式
2020/01/14 Python
python实现将字符串中的数字提取出来然后求和
2020/04/02 Python
几个数据库方面的面试题
2016/07/01 面试题
工商管理专业应届生求职信
2013/11/04 职场文书
竞聘演讲稿范文
2014/01/12 职场文书
2014年五一劳动节社区活动总结
2014/04/14 职场文书
《记金华的双龙洞》教学反思
2014/04/19 职场文书
校园环保标语
2014/06/13 职场文书
归元寺导游词
2015/02/06 职场文书
中国世界遗产导游词
2015/02/13 职场文书
李强感恩观后感
2015/06/17 职场文书
七年级英语教学反思
2016/02/15 职场文书
高三英语教学反思
2016/03/03 职场文书
导游词之安徽醉翁亭
2020/01/10 职场文书