pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
说一说Python logging
Apr 15 Python
Python编程实现控制cmd命令行显示颜色的方法示例
Aug 14 Python
Python模拟简单电梯调度算法示例
Aug 20 Python
python leetcode 字符串相乘实例详解
Sep 03 Python
python的继承知识点总结
Dec 10 Python
python实现移位加密和解密
Mar 22 Python
使用python PIL库实现简单验证码的去噪方法步骤
May 10 Python
Python:Numpy 求平均向量的实例
Jun 29 Python
python使用pip安装模块出现ReadTimeoutError: HTTPSConnectionPool的解决方法
Oct 04 Python
matplotlib之多边形选区(PolygonSelector)的使用
Feb 24 Python
浅谈python数据类型及其操作
May 25 Python
python 爬取吉首大学网站成绩单
Jun 02 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
关于尾递归的使用详解
2013/05/02 PHP
搭建自己的PHP MVC框架详解
2017/08/16 PHP
表单提交验证类
2006/07/14 Javascript
JavaScript 字符串处理函数使用小结
2010/12/02 Javascript
jQuery代码优化之基本事件
2011/11/01 Javascript
JQuery中模拟image的ajaxPrefilter与ajaxTransport处理
2015/06/19 Javascript
NodeJS的Promise的用法解析
2016/05/05 NodeJs
Node.js数据库操作之查询MySQL数据库(二)
2017/03/04 Javascript
Javascript封装id、class与元素选择器方法示例
2017/03/13 Javascript
Vim快速合并行及vim 将文件所有行合并到一行
2017/11/27 Javascript
浅谈Vue2.0父子组件间事件派发机制
2018/01/08 Javascript
vue+高德地图实现地图搜索及点击定位操作
2020/09/09 Javascript
[06:43]DAC2018 4.5 SOLO赛 Maybe vs Paparazi
2018/04/06 DOTA
linux下安装easy_install的方法
2013/02/10 Python
跟老齐学Python之总结参数的传递
2014/10/10 Python
浅谈使用Python变量时要避免的3个错误
2017/10/30 Python
python简单图片操作:打开\显示\保存图像方法介绍
2017/11/23 Python
django中模板的html自动转意方法
2018/05/27 Python
python 随机森林算法及其优化详解
2019/07/11 Python
Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框
2020/09/21 Python
终端业务员岗位职责
2013/11/27 职场文书
大学生工作推荐信范文
2013/12/02 职场文书
大客户销售经理职责
2013/12/04 职场文书
酒吧副总经理岗位职责
2013/12/10 职场文书
市政施工员自我鉴定
2014/01/15 职场文书
《走一步再走一步》教学反思
2014/02/15 职场文书
高三学习决心书
2014/03/11 职场文书
会计工作决心书
2014/03/11 职场文书
艾滋病宣传标语
2014/06/25 职场文书
十佳党员事迹材料
2014/08/28 职场文书
“三支一扶”支教教师思想汇报
2014/09/13 职场文书
警察群众路线整改措施
2014/09/26 职场文书
省委召开党的群众路线教育实践活动总结大会报告
2014/10/21 职场文书
周年庆典答谢词
2015/01/20 职场文书
谁动了我的奶酪读书笔记
2015/06/30 职场文书
如何更改Win11声音输出设备?Win11声音输出设备四种更改方法
2022/04/08 数码科技