pytorch实现对输入超过三通道的数据进行训练


Posted in Python onJanuary 15, 2020

案例背景:视频识别

假设每次输入是8s的灰度视频,视频帧率为25fps,则视频由200帧图像序列构成.每帧是一副单通道的灰度图像,通过pythonb里面的np.stack(深度拼接)可将200帧拼接成200通道的深度数据.进而送到网络里面去训练.

如果输入图像200通道觉得多,可以对视频进行抽帧,针对具体场景可以随机抽帧或等间隔抽帧.比如这里等间隔抽取40帧.则最后输入视频相当于输入一个40通道的图像数据了.

pytorch对超过三通道数据的加载:

读取视频每一帧,转为array格式,然后依次将每一帧进行深度拼接,最后得到一个40通道的array格式的深度数据,保存到pickle里.

对每个视频都进行上述操作,保存到pickle里.

我这里将火的视频深度数据保存在一个.pkl文件中,一共2504个火的视频,即2504个火的深度数据.

将非火的视频深度数据保存在一个.pkl文件中,一共3985个非火的视频,即3985个非火的深度数据.

数据加载

import torch 
from torch.utils import data
import os
from PIL import Image
import numpy as np
import pickle
 
class Fire_Unfire(data.Dataset):
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')
    
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)#高*宽*通道
      fire = fire.transpose(2,0,1)#通道*高*宽
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,0,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
    
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
 
#转换成pytorch网络输入的格式(批量大小,通道数,高,宽)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)

模型训练

import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
from torch.autograd import Variable as V
import numpy as np
import sys
import time
 
opt = default_config()
def train():
  #模型定义
  model = mobilenet().cuda()
  if opt.pretrain_model:
    model.load_state_dict(torch.load(opt.pretrain_model))
  
  #损失函数
  criterion = torch.nn.CrossEntropyLoss().cuda()
  
  #学习率
  lr = opt.lr
  
  #优化器
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  
  
  pre_loss = 0.0
  #训练
  for epoch in range(opt.max_epoch):
     #训练数据
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,drop_last = True)
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #print(i,datas.size(),labels)
      #梯度清零
      optimizer.zero_grad()
      #输入
      input = V(datas.cuda()).float()
      #目标
      target = V(labels.cuda()).long()
      #输出
      score = model(input).cuda()
      #损失
      loss = criterion(score,target)
      loss_sum += loss
      #反向传播
      loss.backward()
      #梯度更新
      optimizer.step()      
    print('{}{}{}{}{}'.format('epoch:',epoch,',','loss:',loss))
    torch.save(model.state_dict(),'models/mobilenet_%d.pth'%(epoch+370))

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

解决方案:target = target.long()

以上这篇pytorch实现对输入超过三通道的数据进行训练就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python从有道词典网页获取单词翻译
Jul 03 Python
Python中django学习心得
Dec 06 Python
读取json格式为DataFrame(可转为.csv)的实例讲解
Jun 05 Python
Python简单爬虫导出CSV文件的实例讲解
Jul 06 Python
Python装饰器模式定义与用法分析
Aug 06 Python
对Python w和w+权限的区别详解
Jan 23 Python
Django项目中添加ldap登陆认证功能的实现
Apr 04 Python
python elasticsearch环境搭建详解
Sep 02 Python
小 200 行 Python 代码制作一个换脸程序
May 12 Python
Python csv文件记录流程代码解析
Jul 16 Python
python 装饰器重要在哪
Feb 14 Python
python 制作本地应用搜索工具
Feb 27 Python
Pytorch 定义MyDatasets实现多通道分别输入不同数据方式
Jan 15 #Python
pytorch构建多模型实例
Jan 15 #Python
利用Pytorch实现简单的线性回归算法
Jan 15 #Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
You might like
在PHP3中实现SESSION的功能(二)
2006/10/09 PHP
《PHP边学边教》(02.Apache+PHP环境配置――上篇)
2006/12/13 PHP
PHP swfupload图片上传的实例代码
2013/09/30 PHP
给WordPress中的留言加上楼层号的PHP代码实例
2015/12/14 PHP
php图片上传类 附调用方法
2016/05/15 PHP
ThinkPHP的SAE开发相关注意事项详解
2016/10/09 PHP
javascript引用对象的方法
2007/01/11 Javascript
dojo 之基础篇(二)之从服务器读取数据
2007/03/24 Javascript
JavaScript 页面编码与浏览器类型判断代码
2010/06/03 Javascript
js 处理URL实用技巧
2010/11/23 Javascript
使用JavaScript开发IE浏览器本地插件实例
2015/02/18 Javascript
JQuery中ajax方法访问web服务实例
2015/07/18 Javascript
js 模仿锚点定位的实现方法
2016/11/19 Javascript
详解基于angular-cli配置代理解决跨域请求问题
2017/07/05 Javascript
深入讲解xhr(XMLHttpRequest)/jsonp请求之abort
2017/07/26 Javascript
vue实现留言板todolist功能
2017/08/16 Javascript
webpack构建react多页面应用详解
2017/09/15 Javascript
3种vue组件的书写形式
2017/11/29 Javascript
tsconfig.json配置详解
2019/05/17 Javascript
[34:44]Liquid vs TNC Supermajor 胜者组 BO3 第二场 6.4
2018/06/05 DOTA
python中提高pip install速度
2020/02/14 Python
Selenium alert 弹窗处理的示例代码
2020/08/06 Python
Python常用断言函数实例汇总
2020/11/30 Python
HTML5页面音视频在微信和app下自动播放的实现方法
2016/10/20 HTML / CSS
中国海淘族值得信赖的海淘返利网站:55海淘
2017/01/16 全球购物
Tom Dixon官网:英国照明及家具设计和制造公司
2019/03/01 全球购物
JD Sports丹麦:英国领先的运动时尚零售商
2020/11/24 全球购物
PyQt QMainWindow的使用示例
2021/03/24 Python
公司同意接收函
2014/01/13 职场文书
通用自荐信范文
2014/03/14 职场文书
物业总经理助理岗位职责
2014/06/29 职场文书
党的群众路线教育实践活动实施方案
2014/10/31 职场文书
房地产工程部经理岗位职责
2015/04/09 职场文书
钢琴师观后感
2015/06/12 职场文书
毕业典礼主持词
2015/06/29 职场文书
pytorch交叉熵损失函数的weight参数的使用
2021/05/24 Python