MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 算法 排序实现快速排序
Jun 05 Python
Python的socket模块源码中的一些实现要点分析
Jun 06 Python
Python实现批量更换指定目录下文件扩展名的方法
Sep 19 Python
使用Python3 编写简单信用卡管理程序
Dec 21 Python
python跳过第一行快速读取文件内容的实例
Jul 12 Python
在Python中构建增广矩阵的实现方法
Jul 01 Python
python3的数据类型及数据类型转换实例详解
Aug 20 Python
Python3搭建http服务器的实现代码
Feb 11 Python
python GUI库图形界面开发之PyQt5表单布局控件QFormLayout详细使用方法与实例
Mar 06 Python
使用Keras建立模型并训练等一系列操作方式
Jul 02 Python
python 利用matplotlib在3D空间中绘制平面的案例
Feb 06 Python
pycharm代码删除恢复的方法
Jun 26 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
星际争霸中的热键
2020/03/04 星际争霸
英雄试炼之肉山谷—引领RPG新潮流
2020/04/20 DOTA
实例介绍PHP中zip_open()函数用法
2019/02/15 PHP
详解提高使用Java反射的效率方法
2019/04/29 PHP
javascript操作文本框readOnly
2007/05/15 Javascript
Jquery 实现Tab效果 思路是js思路
2010/03/02 Javascript
动态的改变IFrame的高度实现IFrame自动伸展适应高度
2012/12/28 Javascript
jQuery对html元素取值与赋值的方法
2013/11/20 Javascript
原生javascript实现拖动元素示例代码
2014/09/01 Javascript
javascript中的五种基本数据类型
2015/08/26 Javascript
利用jQuery及AJAX技术定时更新GridView的某一列数据
2015/12/04 Javascript
基于JavaScript实现右键菜单和拖拽功能
2016/11/28 Javascript
JavaScript定义全局对象的方法示例
2017/01/12 Javascript
jQuery模拟淘宝购物车功能
2017/02/27 Javascript
微信小程序云开发之模拟后台增删改查
2019/05/16 Javascript
selenium 反爬虫之跳过淘宝滑块验证功能的实现代码
2020/08/27 Javascript
[02:36]DOTA2英雄基础教程 斯拉克
2013/11/29 DOTA
Python文件夹与文件的操作实现代码
2014/07/13 Python
Python实现公历(阳历)转农历(阴历)的方法示例
2017/08/22 Python
python查看模块,对象的函数方法
2018/10/16 Python
Python文件循环写入行时防止覆盖的解决方法
2018/11/09 Python
python使用udp实现聊天器功能
2018/12/10 Python
Python+OpenCV实现实时眼动追踪的示例代码
2019/11/11 Python
Python实现自动装机功能案例分析
2020/10/22 Python
css3 media 响应式布局的简单实例
2016/08/03 HTML / CSS
Jacadi Paris英国官网:法国童装品牌
2019/08/09 全球购物
介绍一下MYSQL常用的优化技巧
2012/10/25 面试题
工商干部先进事迹
2014/05/14 职场文书
社区志愿者培训方案
2014/06/10 职场文书
超市周年庆活动方案
2014/08/16 职场文书
光学与应用专业毕业生求职信
2014/09/01 职场文书
流动人口婚育证明范本
2014/09/26 职场文书
运动会搞笑广播稿
2014/10/14 职场文书
2014年设备管理工作总结
2014/11/26 职场文书
实习生个人总结范文
2015/02/28 职场文书
html5移动端禁止长按图片保存的实现
2021/04/20 HTML / CSS