解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python创建和使用字典实例详解
Nov 01 Python
Python创建xml的方法
Mar 10 Python
利用Python中的pandas库对cdn日志进行分析详解
Mar 07 Python
Python cookbook(数据结构与算法)字典相关计算问题示例
Feb 18 Python
python实现雨滴下落到地面效果
Jun 21 Python
Django migrations 默认目录修改的方法教程
Sep 28 Python
python dlib人脸识别代码实例
Apr 04 Python
python反编译学习之字节码详解
May 19 Python
Python多叉树的构造及取出节点数据(treelib)的方法
Aug 09 Python
利用Python的sympy包求解一元三次方程示例
Nov 22 Python
keras实现多种分类网络的方式
Jun 11 Python
python3 删除所有自定义变量的操作
Apr 08 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
从网上搜到的phpwind 0day的代码
2006/12/07 PHP
PHP中10个不常见却非常有用的函数
2010/03/21 PHP
用php来改写404错误页让你的页面更友好
2013/01/24 PHP
PHP导航下拉菜单的实现如此简单
2013/09/22 PHP
PHP的openssl加密扩展使用小结(推荐)
2016/07/18 PHP
Linux平台PHP5.4设置FPM线程数量的方法
2016/11/09 PHP
php设计模式之代理模式分析【星际争霸游戏案例】
2020/03/23 PHP
js/jQuery简单实现选项卡功能
2014/01/02 Javascript
JavaScript中的操作符==与===介绍
2014/12/31 Javascript
JQuery节点元素属性操作方法
2015/06/11 Javascript
AJAX和jQuery动态加载数据的实现方法
2016/12/05 Javascript
JavaScript 输出显示内容(document.write、alert、innerHTML、console.log)
2016/12/14 Javascript
EasyUI创建人员树的实例代码
2017/09/15 Javascript
layui中layer前端组件实现图片显示功能的方法分析
2017/10/13 Javascript
vue实现商城购物车功能
2017/11/27 Javascript
canvas轨迹回放功能实现
2017/12/20 Javascript
解决vue v-for 遍历循环时key值报错的问题
2018/09/06 Javascript
微信小程序中显示倒计时代码实例
2019/05/09 Javascript
创建nuxt.js项目流程图解
2020/03/13 Javascript
[00:53]TI3正赛第三天 DK怒破A队不败金身 现场国旗飘扬热血激昂
2013/08/10 DOTA
Python 正则表达式入门(初级篇)
2016/12/07 Python
Python2实现的LED大数字显示效果示例
2017/09/04 Python
Python 使用folium绘制leaflet地图的实现方法
2019/07/05 Python
简单了解python代码优化小技巧
2019/07/08 Python
匡威意大利官方商店 :Converse意大利
2018/11/27 全球购物
DOUGLAS荷兰:购买香水和化妆品
2020/10/24 全球购物
什么是类的返射机制
2016/02/06 面试题
公共事业管理本科生求职信
2013/10/07 职场文书
幼教个人求职信范文
2013/12/02 职场文书
制药工程专业毕业生推荐信
2013/12/24 职场文书
《望洞庭》教学反思
2014/02/16 职场文书
音乐教学随笔感言
2014/02/19 职场文书
开会通知
2015/04/20 职场文书
运动会宣传语
2015/07/13 职场文书
Python文件的操作示例的详细讲解
2021/04/08 Python
给numpy.array增加维度的超简单方法
2021/06/02 Python