pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python字符串连接的N种方式总结
Sep 17 Python
Linux下通过python访问MySQL、Oracle、SQL Server数据库的方法
Apr 23 Python
Python的numpy库中将矩阵转换为列表等函数的方法
Apr 04 Python
Flask框架配置与调试操作示例
Jul 23 Python
使用pandas把某一列的字符值转换为数字的实例
Jan 29 Python
Python根据成绩分析系统浅析
Feb 11 Python
Python 微信爬虫完整实例【单线程与多线程】
Jul 06 Python
pytorch 模型可视化的例子
Aug 17 Python
基于Python 中函数的 收集参数 机制
Dec 21 Python
解决Pycharm中恢复被exclude的项目问题(pycharm source root)
Feb 14 Python
Django QuerySet查询集原理及代码实例
Jun 13 Python
Django中和时区相关的安全问题详解
Oct 12 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
mysql From_unixtime及UNIX_TIMESTAMP及DATE_FORMAT日期函数
2010/03/21 PHP
PHP教程之PHP中shell脚本的使用方法分享
2012/02/23 PHP
php中file_get_content 和curl以及fopen 效率分析
2014/09/19 PHP
ExtJS 2.2.1的grid控件在ie6中的显示问题
2009/05/04 Javascript
不要在cookie中使用特殊字符的原因分析
2010/07/13 Javascript
javascript怎么禁用浏览器后退按钮
2014/03/27 Javascript
简单方法判断JavaScript对象为null或者属性为空
2014/09/26 Javascript
Jquery中基本选择器用法实例详解
2015/05/18 Javascript
vue.js 获取当前自定义属性值
2017/06/01 Javascript
vue实现计算器功能
2020/02/22 Javascript
在Django的session中使用User对象的方法
2015/07/23 Python
Python实现批量压缩图片
2018/01/25 Python
Python实现的对本地host127.0.0.1主机进行扫描端口功能示例
2019/02/15 Python
python redis连接 有序集合去重的代码
2019/08/04 Python
python-Web-flask-视图内容和模板知识点西宁街
2019/08/23 Python
python实现复制文件到指定目录
2019/10/16 Python
Django项目uwsgi+Nginx保姆级部署教程实现
2020/04/19 Python
Python如何使用PIL Image制作GIF图片
2020/05/16 Python
python 实现读取csv数据,分类求和 再写进 csv
2020/05/18 Python
Keras 数据增强ImageDataGenerator多输入多输出实例
2020/07/03 Python
利用python为PostgreSQL的表自动添加分区
2021/01/18 Python
英国最大的手表网站:The Watch Hut
2017/03/31 全球购物
Street One瑞士:德国现代时装公司
2019/10/09 全球购物
你常见到的runtime exception
2016/09/05 面试题
优秀毕业生自我鉴定
2014/02/11 职场文书
金融管理应届生求职信
2014/02/20 职场文书
有趣的广告词
2014/03/18 职场文书
2014年幼儿园国庆主题活动方案
2014/09/16 职场文书
教育合作协议范本
2014/10/17 职场文书
项目转让协议书
2014/10/27 职场文书
学习普通话的体会
2014/11/07 职场文书
2014年仓库管理工作总结
2014/12/17 职场文书
酒店温馨提示语
2015/07/14 职场文书
2019年销售部季度工作计划3篇
2019/10/09 职场文书
Go语言特点及基本数据类型使用详解
2022/03/21 Golang
Pyhton爬虫知识之正则表达式详解
2022/04/01 Python