在pytorch中动态调整优化器的学习率方式


Posted in Python onJune 24, 2020

在深度学习中,经常需要动态调整学习率,以达到更好地训练效果,本文纪录在pytorch中的实现方法,其优化器实例为SGD优化器,其他如Adam优化器同样适用。

一般来说,在以SGD优化器作为基本优化器,然后根据epoch实现学习率指数下降,代码如下:

step = [10,20,30,40]
base_lr = 1e-4
sgd_opt = torch.optim.SGD(model.parameters(), lr=base_lr, nesterov=True, momentum=0.9)
def adjust_lr(epoch):
 lr = base_lr * (0.1 ** np.sum(epoch >= np.array(step)))
 for params_group in sgd_opt.param_groups:
  params_group['lr'] = lr
 return lr

只需要在每个train的epoch之前使用这个函数即可。

for epoch in range(60):
 model.train()
 adjust_lr(epoch)
 for ind, each in enumerate(train_loader):
 mat, label = each
 ...

补充知识:Pytorch框架下应用Bi-LSTM实现汽车评论文本关键词抽取

需要调用的模块及整体Bi-lstm流程

import torch
import pandas as pd
import numpy as np
from tensorflow import keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import gensim
from sklearn.model_selection import train_test_split
class word_extract(nn.Module):
 def __init__(self,d_model,embedding_matrix):
  super(word_extract, self).__init__()
  self.d_model=d_model
  self.embedding=nn.Embedding(num_embeddings=len(embedding_matrix),embedding_dim=200)
  self.embedding.weight.data.copy_(embedding_matrix)
  self.embedding.weight.requires_grad=False
  self.lstm1=nn.LSTM(input_size=200,hidden_size=50,bidirectional=True)
  self.lstm2=nn.LSTM(input_size=2*self.lstm1.hidden_size,hidden_size=50,bidirectional=True)
  self.linear=nn.Linear(2*self.lstm2.hidden_size,4)

 def forward(self,x):
  w_x=self.embedding(x)
  first_x,(first_h_x,first_c_x)=self.lstm1(w_x)
  second_x,(second_h_x,second_c_x)=self.lstm2(first_x)
  output_x=self.linear(second_x)
  return output_x

将文本转换为数值形式

def trans_num(word2idx,text):
 text_list=[]
 for i in text:
  s=i.rstrip().replace('\r','').replace('\n','').split(' ')
  numtext=[word2idx[j] if j in word2idx.keys() else word2idx['_PAD'] for j in s ]
  text_list.append(numtext)
 return text_list

将Gensim里的词向量模型转为矩阵形式,后续导入到LSTM模型中

def establish_word2vec_matrix(model): #负责将数值索引转为要输入的数据
 word2idx = {"_PAD": 0} # 初始化 `[word : token]` 字典,后期 tokenize 语料库就是用该词典。
 num2idx = {0: "_PAD"}
 vocab_list = [(k, model.wv[k]) for k, v in model.wv.vocab.items()]

 # 存储所有 word2vec 中所有向量的数组,留意其中多一位,词向量全为 0, 用于 padding
 embeddings_matrix = np.zeros((len(model.wv.vocab.items()) + 1, model.vector_size))
 for i in range(len(vocab_list)):
  word = vocab_list[i][0]
  word2idx[word] = i + 1
  num2idx[i + 1] = word
  embeddings_matrix[i + 1] = vocab_list[i][1]
 embeddings_matrix = torch.Tensor(embeddings_matrix)
 return embeddings_matrix, word2idx, num2idx

训练过程

def train(model,epoch,learning_rate,batch_size,x, y, val_x, val_y):
 optimizor = optim.Adam(model.parameters(), lr=learning_rate)
 data = TensorDataset(x, y)
 data = DataLoader(data, batch_size=batch_size)
 for i in range(epoch):
  for j, (per_x, per_y) in enumerate(data):
   output_y = model(per_x)
   loss = F.cross_entropy(output_y.view(-1,output_y.size(2)), per_y.view(-1))
   optimizor.zero_grad()
   loss.backward()
   optimizor.step()
   arg_y=output_y.argmax(dim=2)
   fit_correct=(arg_y==per_y).sum()
   fit_acc=fit_correct.item()/(per_y.size(0)*per_y.size(1))
   print('##################################')
   print('第{}次迭代第{}批次的训练误差为{}'.format(i + 1, j + 1, loss), end=' ')
   print('第{}次迭代第{}批次的训练准确度为{}'.format(i + 1, j + 1, fit_acc))
   val_output_y = model(val_x)
   val_loss = F.cross_entropy(val_output_y.view(-1,val_output_y.size(2)), val_y.view(-1))
   arg_val_y=val_output_y.argmax(dim=2)
   val_correct=(arg_val_y==val_y).sum()
   val_acc=val_correct.item()/(val_y.size(0)*val_y.size(1))
   print('第{}次迭代第{}批次的预测误差为{}'.format(i + 1, j + 1, val_loss), end=' ')
   print('第{}次迭代第{}批次的预测准确度为{}'.format(i + 1, j + 1, val_acc))
 torch.save(model,'./extract_model.pkl')#保存模型

主函数部分

if __name__ =='__main__':
 #生成词向量矩阵
 word2vec = gensim.models.Word2Vec.load('./word2vec_model')
 embedding_matrix,word2idx,num2idx=establish_word2vec_matrix(word2vec)#输入的是词向量模型
 #
 train_data=pd.read_csv('./数据.csv')
 x=list(train_data['文本'])
 # 将文本从文字转化为数值,这部分trans_num函数你需要自己改动去适应你自己的数据集
 x=trans_num(word2idx,x)
 #x需要先进行填充,也就是每个句子都是一样长度,不够长度的以0来填充,填充词单独分为一类
 # #也就是说输入的x是固定长度的数值列表,例如[50,123,1850,21,199,0,0,...]
 #输入的y是[2,0,1,0,0,1,3,3,3,3,3,.....]
 #填充代码你自行编写,以下部分是针对我的数据集
 x=keras.preprocessing.sequence.pad_sequences(
   x,maxlen=60,value=0,padding='post',
 )
 y=list(train_data['BIO数值'])
 y_text=[]
 for i in y:
  s=i.rstrip().split(' ')
  numtext=[int(j) for j in s]
  y_text.append(numtext)
 y=y_text
 y=keras.preprocessing.sequence.pad_sequences(
   y,maxlen=60,value=3,padding='post',
  )
 # 将数据进行划分
 fit_x,val_x,fit_y,val_y=train_test_split(x,y,train_size=0.8,test_size=0.2)
 fit_x=torch.LongTensor(fit_x)
 fit_y=torch.LongTensor(fit_y)
 val_x=torch.LongTensor(val_x)
 val_y=torch.LongTensor(val_y)
 #开始应用
 w_extract=word_extract(d_model=200,embedding_matrix=embedding_matrix)
 train(model=w_extract,epoch=5,learning_rate=0.001,batch_size=50,
   x=fit_x,y=fit_y,val_x=val_x,val_y=val_y)#可以自行改动参数,设置学习率,批次,和迭代次数
 w_extract=torch.load('./extract_model.pkl')#加载保存好的模型
 pred_val_y=w_extract(val_x).argmax(dim=2)

以上这篇在pytorch中动态调整优化器的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3.3教程之模拟百度登陆代码分享
Jan 16 Python
python处理文本文件实现生成指定格式文件的方法
Jul 31 Python
5种Python单例模式的实现方式
Jan 14 Python
Python中的复制操作及copy模块中的浅拷贝与深拷贝方法
Jul 02 Python
Windows下python3.6.4安装教程
Jul 31 Python
Django使用AJAX调用自己写的API接口的方法
Mar 06 Python
django数据关系一对多、多对多模型、自关联的建立
Jul 24 Python
win10子系统python开发环境准备及kenlm和nltk的使用教程
Oct 14 Python
Python简单实现词云图代码及步骤解析
Jun 04 Python
关于探究python中sys.argv时遇到的问题详解
Feb 23 Python
Pandas搭配lambda组合使用详解
Jan 22 Python
Python中tqdm的使用和例子
Sep 23 Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
You might like
Drupal 添加模块出现莫名其妙的错误的解决方法(往往出现在模块较多时)
2011/04/18 PHP
php解析url的三个示例
2014/01/20 PHP
PHP实现使用优酷土豆视频地址获取swf播放器分享地址
2014/06/05 PHP
PHP实现生成透明背景的PNG缩略图函数分享
2014/07/08 PHP
PHP结合jQuery实现找回密码
2015/07/22 PHP
动态表单验证的操作方法和TP框架里面的ajax表单验证
2017/07/19 PHP
一个js实现的所谓的滑动门
2007/05/23 Javascript
JS event使用方法详解
2008/04/28 Javascript
jQuery插件实现屏蔽单个元素使用户无法点击
2013/04/12 Javascript
javascript 密码框防止用户粘贴和复制的实现代码
2014/02/17 Javascript
jquery动态加载js/css文件方法(自写小函数)
2014/10/11 Javascript
JavaScript检测实例属性, 原型属性
2015/02/04 Javascript
jQuery插件之Tocify动态节点目录菜单生成器附源码下载
2016/01/08 Javascript
原生JS:Date对象全面解析
2016/09/06 Javascript
vue使用axios跨域请求数据问题详解
2017/10/18 Javascript
Vue项目路由刷新的实现代码
2019/04/17 Javascript
javascript的this关键字详解
2019/05/20 Javascript
Vue 嵌套路由使用总结(推荐)
2020/01/13 Javascript
Vue $attrs & inheritAttr实现button禁用效果案例
2020/12/07 Vue.js
[04:29]DOTA2亚洲邀请赛小组赛第一日 TOP10精彩集锦
2015/02/01 DOTA
详解Django+Uwsgi+Nginx的生产环境部署
2018/06/25 Python
Django 多语言教程的实现(i18n)
2018/07/07 Python
python使用requests.session模拟登录
2019/08/09 Python
python-numpy-指数分布实例详解
2019/12/07 Python
通过Python实现一个简单的html页面
2020/05/16 Python
CSS3 实现的加载动画
2020/12/07 HTML / CSS
详解使用canvas保存网页为pdf文件支持跨域
2018/11/23 HTML / CSS
HTML5拖拽文件到浏览器并实现文件上传下载功能代码
2013/06/06 HTML / CSS
管理失职检讨书
2014/02/12 职场文书
内刊编辑求职自荐书范文
2014/02/19 职场文书
大学毕业生推荐信
2014/07/09 职场文书
土地转让协议书
2014/09/27 职场文书
2014年高校辅导员工作总结
2014/12/09 职场文书
教师节感想
2015/08/11 职场文书
2016年优秀教师先进事迹材料
2016/02/26 职场文书
Python爬取奶茶店数据分析哪家最好喝以及性价比
2022/09/23 Python