pytorch 使用半精度模型部署的操作


Posted in Python onMay 24, 2021

背景

pytorch作为深度学习的计算框架正得到越来越多的应用.

我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.

在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.

但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍

具体方法

在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可

另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.

补充:pytorch 使用amp.autocast半精度加速训练

准备工作

pytorch 1.6+

如何使用autocast?

根据官方提供的方法,

答案就是autocast + GradScaler。

1,autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

如何在PyTorch中使用自动混合精度?

答案:autocast + GradScaler。

1.autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()

2.GradScaler

GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。

因此PyTorch中经典的AMP使用方式如下:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

3.nn.DataParallel

单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。

要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_data_c1):
     with autocast():
      # code
     return

只要把forward里面的代码用autocast代码块方式运行就好啦!

自动进行autocast的操作

如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

1、matmul

2、addbmm

3、addmm

4、addmv

5、addr

6、baddbmm

7、bmm

8、chain_matmul

9、conv1d

10、conv2d

11、conv3d

12、conv_transpose1d

13、conv_transpose2d

14、conv_transpose3d

15、linear

16、matmul

17、mm

18、mv

19、prelu

那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。

Python 相关文章推荐
使用Python编写Linux系统守护进程实例
Feb 03 Python
使用python编写监听端
Apr 12 Python
django将图片上传数据库后在前端显式的方法
May 25 Python
python得到windows自启动列表的方法
Oct 14 Python
python实现简单名片管理系统
Nov 30 Python
python自动发送测试报告邮件功能的实现
Jan 22 Python
在Pandas中DataFrame数据合并,连接(concat,merge,join)的实例
Jan 29 Python
Python及Pycharm安装方法图文教程
Aug 05 Python
python输出决策树图形的例子
Aug 09 Python
解决pycharm不能自动补全第三方库的函数和属性问题
Mar 12 Python
TensorFLow 数学运算的示例代码
Apr 21 Python
结束运行python的方法
Jun 16 Python
解决Pytorch半精度浮点型网络训练的问题
May 24 #Python
Python办公自动化之Excel(中)
May 24 #Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 #Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
You might like
php 操作excel文件的方法小结
2009/12/31 PHP
EXTJS FORM HIDDEN TEXTFIELD 赋值 使用value不好用的问题
2011/04/16 Javascript
用C/C++来实现 Node.js 的模块(一)
2014/09/24 Javascript
JavaScript中的cacheStorage使用详解
2015/07/29 Javascript
jQuery实现textarea自动增长宽高的方法
2015/12/18 Javascript
Bootstrap 布局组件(全)
2016/07/18 Javascript
JavaScript 拖拽实例代码
2016/09/21 Javascript
BootStrap tab选项卡使用小结
2020/08/09 Javascript
jQuery设置Easyui校验规则(推荐)
2016/11/21 Javascript
js中toString()和String()区别详解
2017/03/23 Javascript
Jquery中.bind()、.live()、.delegate()和.on()之间的区别详解
2017/08/01 jQuery
Webpack 之 babel-loader文件预处理器详解
2018/03/23 Javascript
30分钟快速入门掌握ES6/ES2015的核心内容(下)
2018/04/18 Javascript
jQuery实现图片下载代码
2019/07/18 jQuery
jquery选择器和属性对象的操作实例分析
2020/01/10 jQuery
node.js开发辅助工具nodemon安装与配置详解
2020/02/06 Javascript
详解JS深拷贝与浅拷贝
2020/08/04 Javascript
python中bisect模块用法实例
2014/09/25 Python
python实现linux下使用xcopy的方法
2015/06/28 Python
Python基础之getpass模块详细介绍
2017/08/10 Python
使用Python将Mysql的查询数据导出到文件的方法
2019/02/25 Python
中国领先的专业家电网购平台:国美在线
2016/12/25 全球购物
Exoticca英国:以最优惠的价格提供豪华异国情调旅行
2018/10/18 全球购物
荷兰网上药店:Drogisterij.net
2019/09/03 全球购物
RUIFIER官网:英国奢侈高级珠宝品牌
2020/06/12 全球购物
比较基础的php面试题及答案-编程题
2012/10/14 面试题
移动通信行业实习自我鉴定
2013/09/28 职场文书
公司业务主管岗位职责
2013/12/07 职场文书
高中毕业自我鉴定
2013/12/16 职场文书
旷课检讨书500字
2014/10/14 职场文书
2016年大学生社会实践心得体会
2015/10/09 职场文书
对Keras自带Loss Function的深入研究
2021/05/25 Python
详解CSS玩转图片Base64编码
2021/05/25 HTML / CSS
使用Pytorch训练two-head网络的操作
2021/05/28 Python
django 认证类配置实现
2021/11/11 Python
如何让你的Nginx支持分布式追踪详解
2022/07/07 Servers