pytorch中的model=model.to(device)使用说明


Posted in Python onMay 24, 2021

这代表将模型加载到指定设备上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。

当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。

将由GPU保存的模型加载到CPU上。

将torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由CPU保存的模型加载到GPU上。

确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。

最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

补充:pytorch中model.to(device)和map_location=device的区别

一、简介

在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。

将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。

调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。

二、实例

了解了两者代表的意义,以下介绍两者的使用。

1、保存在GPU上,在CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解释:

在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。

2、保存在GPU上,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。

此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。

请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不会覆盖 my_tensor。

因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解释:

在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。

这会将模型加载到给定的GPU设备。

接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。

最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。

请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不会覆盖my_tensor。

因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的keyword模块用法实例分析
Jun 30 Python
分析python切片原理和方法
Dec 19 Python
Python使用pymysql从MySQL数据库中读出数据的方法
Jul 25 Python
python点击鼠标获取坐标(Graphics)
Aug 10 Python
python提取照片坐标信息的实例代码
Aug 14 Python
python nmap实现端口扫描器教程
May 28 Python
python3图片文件批量重命名处理
Oct 31 Python
Python实现AI自动抠图实例解析
Mar 05 Python
Django静态资源部署404问题解决方案
May 11 Python
一文轻松掌握python语言命名规范规则
Jun 18 Python
tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this T
Jun 22 Python
python实现杨辉三角的几种方法代码实例
Mar 02 Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
Django实现翻页的示例代码
May 24 #Python
pytorch--之halfTensor的使用详解
pandas DataFrame.shift()函数的具体使用
May 24 #Python
教你怎么用python实现字符串转日期
May 24 #Python
You might like
ThinkPHP惯例配置文件详解
2014/07/14 PHP
PHP运用foreach神奇的转换数组(实例讲解)
2018/02/01 PHP
Swoole 5将移除自动添加Event::wait()特性详解
2019/07/10 PHP
Yii框架使用PHPExcel导出Excel文件的方法分析【改进版】
2019/07/24 PHP
用JS实现一个TreeMenu效果分享
2011/08/28 Javascript
关于jQuery新的事件绑定机制on()的使用技巧
2013/04/26 Javascript
js关于命名空间的函数实例
2015/02/05 Javascript
js兼容pc端浏览器并有多种弹出小提示的手机端浮层控件实例
2015/04/29 Javascript
浅谈window.onbeforeunload() 事件调用ajax
2016/06/29 Javascript
js前端解决跨域问题的8种方案(最新最全)
2016/11/18 Javascript
使用ionic切换页面卡顿的解决方法
2016/12/16 Javascript
element-ui循环显示radio控件信息的方法
2018/08/24 Javascript
深入浅析Vue 中 ref 的使用
2019/04/29 Javascript
JavaScrip如果基于url实现图片下载
2020/07/03 Javascript
Python AES加密实例解析
2018/01/18 Python
python的Crypto模块实现AES加密实例代码
2018/01/22 Python
python 把文件中的每一行以数组的元素放入数组中的方法
2018/04/29 Python
编写多线程Python服务器 最适合基础
2018/09/14 Python
对python字典过滤条件的实例详解
2019/01/22 Python
在OpenCV里实现条码区域识别的方法示例
2019/12/04 Python
python要安装在哪个盘
2020/06/15 Python
使用canvas生成含有微信头像的邀请海报没有微信头像问题
2019/10/29 HTML / CSS
南非最受欢迎的时尚品牌:MRP
2016/09/18 全球购物
英国最受欢迎的价格比较网站之一:MoneySuperMarket
2018/12/19 全球购物
alice McCALL官网:澳大利亚时尚品牌
2020/11/16 全球购物
HttpServlet类中的主要方法都有哪些?各自的作用是什么?
2014/03/16 面试题
批评与自我批评材料
2014/02/15 职场文书
市场营销求职信范文
2014/02/21 职场文书
历史学专业求职信
2014/06/19 职场文书
幼儿园门卫岗位职责范本
2014/07/02 职场文书
教师师德考核自我评价
2014/09/13 职场文书
2014年“四风”问题个人整改措施
2014/09/17 职场文书
锦旗赠语
2015/06/23 职场文书
SpringAop日志找不到方法的处理
2021/06/21 Java/Android
利用Matlab绘制各类特殊图形的实例代码
2021/07/16 Python
详解Spring Bean的配置方式与实例化
2022/06/10 Java/Android