PyTorch预训练Bert模型的示例


Posted in Python onNovember 17, 2020

本文介绍以下内容:
1. 使用transformers框架做预训练的bert-base模型;
2. 开发平台使用Google的Colab平台,白嫖GPU加速;
3. 使用datasets模块下载IMDB影评数据作为训练数据。

transformers模块简介

transformers框架为Huggingface开源的深度学习框架,支持几乎所有的Transformer架构的预训练模型。使用非常的方便,本文基于此框架,尝试一下预训练模型的使用,简单易用。

本来打算预训练bert-large模型,发现colab上GPU显存不够用,只能使用base版本了。打开colab,并且设置好GPU加速,接下来开始介绍代码。

代码实现

首先安装数据下载模块和transformers包。

pip install datasets
pip install transformers

使用datasets下载IMDB数据,返回DatasetDict类型的数据.返回的数据是文本类型,需要进行编码。下面会使用tokenizer进行编码。

from datasets import load_dataset

imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 打印前3条训练数据

接下来加载tokenizer和模型.从transformers导入AutoModelForSequenceClassification, AutoTokenizer,创建模型和tokenizer。

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

对原始数据进行编码,并且分批次(batch)

def preprocessing_func(examples):
  return tokenizer(examples['text'], 
           padding=True,
           truncation=True, max_length=300)

batch_size = 16

encoded_data = imdb.map(preprocessing_func, batched=True, batch_size=batch_size)

上面得到编码数据,每个批次设置为16.接下来需要指定训练的参数,训练参数的指定使用transformers给出的接口类TrainingArguments,模型的训练可以使用Trainer。

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
  'out',
  per_device_train_batch_size=batch_size,
  per_device_eval_batch_size=batch_size,
  learning_rate=5e-5,
  evaluation_strategy='epoch',
  num_train_epochs=10,
  load_best_model_at_end=True,
)

trainer = Trainer(
  model,
  args=args,
  train_dataset=encoded_data['train'],
  eval_dataset=encoded_data['test'],
  tokenizer=tokenizer
)

训练模型使用trainer对象的train方法

trainer.train()

PyTorch预训练Bert模型的示例

评估模型使用trainer对象的evaluate方法

trainer.evaluate()

总结

本文介绍了基于transformers框架实现的bert预训练模型,此框架提供了非常友好的接口,可以方便读者尝试各种预训练模型。同时datasets也提供了很多数据集,便于学习NLP的各种问题。加上Google提供的colab环境,数据下载和预训练模型下载都非常快,建议读者自行去炼丹。本文完整的案例下载

以上就是PyTorch预训练Bert模型的示例的详细内容,更多关于PyTorch预训练Bert模型的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
pymongo实现多结果进行多列排序的方法
May 16 Python
Linux下将Python的Django项目部署到Apache服务器
Dec 24 Python
Tensorflow之构建自己的图片数据集TFrecords的方法
Feb 07 Python
Django自定义过滤器定义与用法示例
Mar 22 Python
基于pip install django失败时的解决方法
Jun 12 Python
Python用于学习重要算法的模块pygorithm实例浅析
Aug 16 Python
python读取并写入mat文件的方法
Jul 12 Python
python 安装impala包步骤
Mar 28 Python
Python urllib.request对象案例解析
May 11 Python
Python3安装模块报错Microsoft Visual C++ 14.0 is required的解决方法
Jul 28 Python
简述python四种分词工具,盘点哪个更好用?
Apr 13 Python
Python语言中的数据类型-序列
Feb 24 Python
python 下载文件的多种方法汇总
Nov 17 #Python
python跨文件使用全局变量的实现
Nov 17 #Python
Python中logging日志的四个等级和使用
Nov 17 #Python
Python爬虫破解登陆哔哩哔哩的方法
Nov 17 #Python
appium+python自动化配置(adk、jdk、node.js)
Nov 17 #Python
python调用百度API实现人脸识别
Nov 17 #Python
详解利用python识别图片中的条码(pyzbar)及条码图片矫正和增强
Nov 17 #Python
You might like
使用Apache的rewrite技术
2006/06/22 PHP
Linux fgetcsv取得的数组元素为空字符串的解决方法
2011/11/25 PHP
PHP将整个网站生成HTML纯静态网页的方法总结
2012/02/05 PHP
PHP中“简单工厂模式”实例代码讲解
2012/09/04 PHP
让codeigniter与swfupload整合的最佳解决方案
2014/06/12 PHP
php 生成Tab键或逗号分隔的CSV
2016/09/24 PHP
利用PHP判断文件是否为图片的方法总结
2017/01/06 PHP
PHP实现提高SESSION响应速度的几种方法详解
2019/08/09 PHP
php反序列化长度变化尾部字符串逃逸(0CTF-2016-piapiapia)
2020/02/15 PHP
JavaScript格式化日期时间的方法和自定义格式化函数示例
2014/04/04 Javascript
jQuery中[attribute^=value]选择器用法实例
2014/12/31 Javascript
JS+CSS实现Li列表隔行换色效果的方法
2015/02/16 Javascript
jQuery实现鼠标滑过链接控制图片的滑动展开与隐藏效果
2015/10/28 Javascript
Vue.js 递归组件实现树形菜单(实例分享)
2016/12/21 Javascript
jQuery Validate 数组 全部验证问题
2017/01/12 Javascript
原生js实现倒计时--2018
2017/02/21 Javascript
div中文字内容溢出常见的解决方法
2017/03/16 Javascript
微信小程序 chooseImage选择图片或者拍照
2017/04/07 Javascript
vue+mockjs模拟数据实现前后端分离开发的实例代码
2017/08/08 Javascript
在vue-cli中组件通信的方法
2017/12/16 Javascript
js实现以最简单的方式将数组元素添加到对象中的方法
2017/12/20 Javascript
Node.js 利用cheerio制作简单的网页爬虫示例
2018/03/01 Javascript
vue 录制视频并压缩视频文件的方法
2018/07/27 Javascript
vue favicon设置以及动态修改favicon的方法
2018/12/21 Javascript
vue-cli3使用mock数据的方法分析
2020/03/16 Javascript
[04:26]DOTA2上海特锦赛小组赛第二日 TOP10精彩集锦
2016/02/27 DOTA
python实现稀疏矩阵示例代码
2017/06/09 Python
使用matplotlib画散点图的方法
2018/05/25 Python
Python中使用遍历在列表中添加字典遇到的坑
2019/02/27 Python
python处理“
2019/06/10 Python
安装python3.7编译器后如何正确安装opnecv的方法详解
2020/06/16 Python
Marriott中国:万豪国际酒店查询预订
2016/09/02 全球购物
公司总经理工作职责管理办法
2014/02/28 职场文书
合作协议书
2014/04/23 职场文书
全国税务系统先进集体事迹材料
2014/05/19 职场文书
债务纠纷委托书
2014/08/30 职场文书