pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 正则表达式 概述及常用字符
May 04 Python
Python如何快速实现分布式任务
Jul 06 Python
Python实现对字典分别按键(key)和值(value)进行排序的方法分析
Dec 19 Python
5款Python程序员高频使用开发工具推荐
Apr 10 Python
python中时间、日期、时间戳的转换的实现方法
Jul 06 Python
python3实现mysql导出excel的方法
Jul 31 Python
简单了解django文件下载方式
Feb 10 Python
Macbook安装Python最新版本、GUI开发环境、图像处理、视频处理环境详解
Feb 17 Python
python requests包的request()函数中的参数-params和data的区别介绍
May 05 Python
python中判断文件结束符的具体方法
Aug 04 Python
Ubuntu16安装Python3.9的实现步骤
Dec 15 Python
python绘图subplots函数使用模板的示例代码
Apr 30 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
手把手教你使用DedeCms V3的在线采集图文教程
2007/04/03 PHP
Destoon实现多表查询示例
2014/08/21 PHP
PHP7.0安装笔记整理
2015/08/28 PHP
PHPExcel笔记, mpdf导出
2016/05/03 PHP
php版微信公众账号第三方管理工具开发简明教程
2016/09/23 PHP
如何利用预加载优化Laravel Model查询详解
2017/08/11 PHP
刷新时清空文本框内容的js代码
2007/04/23 Javascript
js 编程笔记 无名函数
2011/06/28 Javascript
javascript动态生成树形菜单的方法
2015/11/14 Javascript
Javascript 获取鼠标当前的位置实现方法
2016/10/27 Javascript
vue.js指令v-model实现方法
2016/12/05 Javascript
jquery插件bootstrapValidator表单验证详解
2016/12/15 Javascript
JS实现的判断方法、变量是否存在功能示例
2020/03/28 Javascript
VUE中setTimeout和setInterval自动销毁案例
2020/09/07 Javascript
[07:12]2014DOTA2西雅图国际邀请赛 黑马Liquid专题采访
2014/07/12 DOTA
Python爬虫DNS解析缓存方法实例分析
2017/06/02 Python
Python中sort和sorted函数代码解析
2018/01/25 Python
Python subprocess模块详细解读
2018/01/29 Python
使用Python制作自动推送微信消息提醒的备忘录功能
2018/09/06 Python
Python log模块logging记录打印用法解析
2020/01/20 Python
Python高并发和多线程有什么关系
2020/11/14 Python
python如何获得list或numpy数组中最大元素对应的索引
2020/11/16 Python
Python 打印自己设计的字体的实例讲解
2021/01/04 Python
如何处理简单的PHP错误
2015/10/14 面试题
学生自我鉴定
2013/12/18 职场文书
大学生旷课检讨书
2014/01/22 职场文书
建筑专业毕业生自荐信
2014/05/25 职场文书
工业设计专业自荐书
2014/06/05 职场文书
综治工作心得体会
2014/09/11 职场文书
临床医学生职业规划书范文
2014/10/25 职场文书
交通事故死亡赔偿协议书
2014/12/03 职场文书
中秋客户感谢信
2015/01/22 职场文书
售后前台接待岗位职责
2015/04/03 职场文书
防溺水主题班会教案
2015/08/12 职场文书
浅析NIO系列之TCP
2021/06/15 Java/Android
JavaScript流程控制(分支)
2021/12/06 Javascript