PyTorch预训练Bert模型的示例


Posted in Python onNovember 17, 2020

本文介绍以下内容:
1. 使用transformers框架做预训练的bert-base模型;
2. 开发平台使用Google的Colab平台,白嫖GPU加速;
3. 使用datasets模块下载IMDB影评数据作为训练数据。

transformers模块简介

transformers框架为Huggingface开源的深度学习框架,支持几乎所有的Transformer架构的预训练模型。使用非常的方便,本文基于此框架,尝试一下预训练模型的使用,简单易用。

本来打算预训练bert-large模型,发现colab上GPU显存不够用,只能使用base版本了。打开colab,并且设置好GPU加速,接下来开始介绍代码。

代码实现

首先安装数据下载模块和transformers包。

pip install datasets
pip install transformers

使用datasets下载IMDB数据,返回DatasetDict类型的数据.返回的数据是文本类型,需要进行编码。下面会使用tokenizer进行编码。

from datasets import load_dataset

imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 打印前3条训练数据

接下来加载tokenizer和模型.从transformers导入AutoModelForSequenceClassification, AutoTokenizer,创建模型和tokenizer。

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

对原始数据进行编码,并且分批次(batch)

def preprocessing_func(examples):
  return tokenizer(examples['text'], 
           padding=True,
           truncation=True, max_length=300)

batch_size = 16

encoded_data = imdb.map(preprocessing_func, batched=True, batch_size=batch_size)

上面得到编码数据,每个批次设置为16.接下来需要指定训练的参数,训练参数的指定使用transformers给出的接口类TrainingArguments,模型的训练可以使用Trainer。

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
  'out',
  per_device_train_batch_size=batch_size,
  per_device_eval_batch_size=batch_size,
  learning_rate=5e-5,
  evaluation_strategy='epoch',
  num_train_epochs=10,
  load_best_model_at_end=True,
)

trainer = Trainer(
  model,
  args=args,
  train_dataset=encoded_data['train'],
  eval_dataset=encoded_data['test'],
  tokenizer=tokenizer
)

训练模型使用trainer对象的train方法

trainer.train()

PyTorch预训练Bert模型的示例

评估模型使用trainer对象的evaluate方法

trainer.evaluate()

总结

本文介绍了基于transformers框架实现的bert预训练模型,此框架提供了非常友好的接口,可以方便读者尝试各种预训练模型。同时datasets也提供了很多数据集,便于学习NLP的各种问题。加上Google提供的colab环境,数据下载和预训练模型下载都非常快,建议读者自行去炼丹。本文完整的案例下载

以上就是PyTorch预训练Bert模型的示例的详细内容,更多关于PyTorch预训练Bert模型的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python实现读取目录所有文件的文件名并保存到txt文件代码
Nov 22 Python
使用Python脚本对Linux服务器进行监控的教程
Apr 02 Python
Python2.6版本中实现字典推导 PEP 274(Dict Comprehensions)
Apr 28 Python
Python爬取APP下载链接的实现方法
Sep 30 Python
python数据清洗系列之字符串处理详解
Feb 12 Python
详解Python import方法引入模块的实例
Aug 02 Python
Python语言实现将图片转化为html页面
Dec 06 Python
Python3实现的简单验证码识别功能示例
May 02 Python
Django管理员账号和密码忘记的完美解决方法
Dec 06 Python
python脚本之一键移动自定格式文件方法实例
Sep 02 Python
python+requests实现接口测试的完整步骤
Oct 27 Python
python批量生成身份证号到Excel的两种方法实例
Jan 14 Python
python 下载文件的多种方法汇总
Nov 17 #Python
python跨文件使用全局变量的实现
Nov 17 #Python
Python中logging日志的四个等级和使用
Nov 17 #Python
Python爬虫破解登陆哔哩哔哩的方法
Nov 17 #Python
appium+python自动化配置(adk、jdk、node.js)
Nov 17 #Python
python调用百度API实现人脸识别
Nov 17 #Python
详解利用python识别图片中的条码(pyzbar)及条码图片矫正和增强
Nov 17 #Python
You might like
咖啡磨器 如何选购一台适合家用的意式磨豆机
2021/03/05 新手入门
php实现ip白名单黑名单功能
2015/03/12 PHP
PHP中soap用法示例【SoapServer服务端与SoapClient客户端编写】
2018/12/25 PHP
Javascript 八进制转义字符(8进制)
2011/04/08 Javascript
JS 实现点击a标签的时候让其背景更换
2013/10/15 Javascript
javascript与jquery中跳出循环的区别总结
2013/11/04 Javascript
jquery插件之定时查询待处理任务数量
2014/05/01 Javascript
jquery实现动画菜单的左右滚动、渐变及图形背景滚动等效果
2015/08/25 Javascript
jQuery 1.9.1源码分析系列(十五)之动画处理
2015/12/03 Javascript
浅析jQuery移动开发中内联按钮和分组按钮的编写
2015/12/04 Javascript
JavaScript驾驭网页-CSS与DOM
2016/03/24 Javascript
用JS写的一个Ajax库(实例代码)
2016/08/06 Javascript
Bootstrap免费字体和图标网站(值得收藏)
2017/03/16 Javascript
Node.js中看JavaScript的引用
2017/04/22 Javascript
利用JS对iframe父子(内外)页面进行操作的方法教程
2017/06/15 Javascript
JavaScript类的继承方法小结【组合继承分析】
2018/07/11 Javascript
vue调试工具vue-devtools安装及使用方法
2018/11/07 Javascript
vue两组件间值传递 $router.push实现方法
2019/05/15 Javascript
仿iPhone通讯录制作小程序自定义选择组件的实现
2019/05/23 Javascript
小程序实现锚点滑动效果
2019/09/23 Javascript
[02:29]大剑、皮鞭、女装,这届DOTA2勇士令状里都有
2020/07/17 DOTA
python pycurl验证basic和digest认证的方法
2018/05/02 Python
PowerBI和Python关于数据分析的对比
2019/07/11 Python
Python 实现向word(docx)中输出
2020/02/13 Python
推荐8款常用的Python GUI图形界面开发框架
2020/02/23 Python
高清屏下canvas重置尺寸引发的问题的解决
2019/10/14 HTML / CSS
命名空间(namespace)和程序集(Assembly)有什么区别
2015/09/25 面试题
Boolean b = new Boolean(“abcde”); 会编译错误码
2013/11/27 面试题
护士毕业生自荐信
2014/02/07 职场文书
个人贷款授权委托书样本
2014/10/07 职场文书
二年级语文上册复习计划
2015/01/19 职场文书
上班迟到检讨书范文
2015/05/06 职场文书
外出考察学习心得体会
2016/01/18 职场文书
Python中的min及返回最小值索引的操作
2021/05/10 Python
pytorch 6 batch_train 批训练操作
2021/05/28 Python
SQL Server中搜索特定的对象
2022/05/25 SQL Server