pytorch fine-tune 预训练的模型操作


Posted in Python onJune 03, 2021

之一:

torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。

安装

pip install torchvision

如何 fine-tune

以 resnet18 为例:

from torchvision import models
from torch import nn
from torch import optim
 
resnet_model = models.resnet18(pretrained=True) 
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。
 
# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)
 
# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。
 
# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了
 
# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
    para.requires_grad=False 
 
optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)
 
...

为什么

这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么

这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:

预训练的模型中 有个 名字叫fc 的 Module。

在类定义外,我们 将另一个 Module 重新 赋值给了 fc。

类定义内的 fc 对应的 Module 就会从 模型中 删除。

之二:

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

pytorch fine-tune 预训练的模型操作

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
        m.weight.data.fill_(1)
        m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
 
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())
 
optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

之三:

pytorch finetune模型

文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。

pytorch 模型的存储与读取

其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的

单独存储模型参数

存储时使用:

torch.save(the_model.state_dict(), PATH)

读取时:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

存储模型与参数

存储:

torch.save(the_model, PATH)

读取:

the_model = torch.load(PATH)

模型的参数

fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。

pytorch模型参数的形式

模型的参数是以字典的形式存储的。

model_dict = the_model.state_dict(),
for k,v in model_dict.items():
    print(k)

即可看到所有的键值

如果想修改模型的参数,给相应的键值赋值即可

model_dict[k] = new_value

最后更新模型的参数

the_model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
    keys.append(k)
i = 0
for k,v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
        print(k, ',', keys[i])
         model_dict[k]=pretrained_dict[keys[i]]
    i = i + 1
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中无限元素列表的实现方法
Aug 18 Python
python生成词云的实现方法(推荐)
Jun 13 Python
LRUCache的实现原理及利用python实现的方法
Nov 21 Python
Python3实现的Mysql数据库操作封装类
Jun 06 Python
python学生管理系统开发
Jan 30 Python
django2.0扩展用户字段示例
Feb 13 Python
django的ORM模型的实现原理
Mar 04 Python
Python 中list ,set,dict的大规模查找效率对比详解
Oct 11 Python
python获取array中指定元素的示例
Nov 26 Python
pytorch::Dataloader中的迭代器和生成器应用详解
Jan 03 Python
VS2019+python3.7+opencv4.1+tensorflow1.13配置详解
Apr 16 Python
在python下实现word2vec词向量训练与加载实例
Jun 09 Python
Python实现byte转integer
Jun 03 #Python
Python数据分析之绘图和可视化详解
Python数据分析之pandas读取数据
Jun 02 #Python
Python 如何实现文件自动去重
python状态机transitions库详解
Jun 02 #Python
python爬取某网站原图作为壁纸
Python爬虫之自动爬取某车之家各车销售数据
You might like
在Windows下编译适用于PHP 5.2.12及5.2.13的eAccelerator.dll(附下载)
2010/05/04 PHP
php广告加载类用法实例
2014/09/23 PHP
Nigma vs Alliance BO5 第五场2.14
2021/03/10 DOTA
Firefox下提示illegal character并出现乱码的原因
2010/03/25 Javascript
JS中如何设置readOnly的值
2013/12/25 Javascript
js获取字符串最后一位方法汇总
2014/11/13 Javascript
node.js中的favicon.ico请求问题处理
2014/12/15 Javascript
JQuery中DOM加载与事件执行实例分析
2015/06/13 Javascript
js实现tab切换效果实例
2015/09/16 Javascript
JavaScript实现上下浮动的窗口效果代码
2015/10/12 Javascript
jQuery实现带延时功能的水平多级菜单效果【附demo源码下载】
2016/09/21 Javascript
Bootstrap3 多选和单选框(checkbox)
2016/12/29 Javascript
详解vue 模版组件的三种用法
2017/07/21 Javascript
jQuery的时间datetime控件在AngularJs中的使用实例(分享)
2017/08/17 jQuery
AngularJS中重新加载当前路由页面的方法
2018/03/09 Javascript
Vue插槽原理与用法详解
2019/03/05 Javascript
js中调用微信的扫描二维码功能的实现代码
2020/04/11 Javascript
[02:51]DOTA2英雄基础教程 风暴之灵
2013/12/23 DOTA
[05:10]2014DOTA2国际邀请赛 通往胜利之匙赛场探秘之旅
2014/07/18 DOTA
python 不关闭控制台的实现方法
2011/10/23 Python
python检测服务器是否正常
2014/02/16 Python
在Python的Django框架中加载模版的方法
2015/07/16 Python
Python爬虫小技巧之伪造随机的User-Agent
2018/09/13 Python
Python3.5 Pandas模块缺失值处理和层次索引实例详解
2019/04/23 Python
django创建最简单HTML页面跳转方法
2019/08/16 Python
Python urllib2运行过程原理解析
2020/06/04 Python
opencv+pyQt5实现图片阈值编辑器/寻色块阈值利器
2020/11/13 Python
python 写一个水果忍者游戏
2021/01/13 Python
找到您丢失的钥匙、钱包和手机:Tile
2017/05/19 全球购物
教师演讲稿大全
2014/05/16 职场文书
法人代表证明书
2014/09/18 职场文书
资源环境与城乡规划管理专业自荐书
2014/09/26 职场文书
2016关于预防职务犯罪的心得体会
2016/01/21 职场文书
爱岗敬业先进典型事迹材料(2016推荐版)
2016/02/26 职场文书
PostgreSQL自动更新时间戳实例代码
2021/11/27 PostgreSQL
Redis命令处理过程源码解析
2022/02/12 Redis