pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用redis pool的一种单例实现方式
Apr 16 Python
Python类的动态修改的实例方法
Mar 24 Python
Python3安装Pymongo详细步骤
May 26 Python
python使用pycharm环境调用opencv库
Feb 11 Python
Python 查找list中的某个元素的所有的下标方法
Jun 27 Python
Python迭代器与生成器用法实例分析
Jul 09 Python
Python多项式回归的实现方法
Mar 11 Python
Django项目中添加ldap登陆认证功能的实现
Apr 04 Python
Django重置migrations文件的方法步骤
May 01 Python
python实现图片上添加图片
Nov 26 Python
python 安装库几种方法之cmd,anaconda,pycharm详解
Apr 08 Python
Python小白垃圾回收机制入门
Jun 09 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
基于header的一些常用指令详解
2013/06/06 PHP
PHP return语句另类用法不止是在函数中
2014/09/17 PHP
php实现字符串首字母大写和单词首字母大写的方法
2015/03/14 PHP
php类常量用法实例分析
2015/07/09 PHP
PHP登录验证码的实现与使用方法
2016/07/07 PHP
php reset() 函数指针指向数组中的第一个元素并输出实例代码
2016/11/21 PHP
PHP实现的Redis多库选择功能单例类
2017/07/27 PHP
PHP date_default_timezone_set()设置时区操作实例分析
2020/05/16 PHP
一段非常简单的让图片自动切换js代码
2006/11/10 Javascript
Javascript解决常见浏览器兼容问题的12种方法
2010/01/04 Javascript
jquery $.ajax相关用法分享
2012/03/16 Javascript
jquery遍历筛选数组的几种方法和遍历解析json对象
2013/12/13 Javascript
javascript中typeof的使用示例
2013/12/19 Javascript
javascript操作referer详细解析
2014/03/10 Javascript
javascript中HTMLDOM操作详解
2014/12/11 Javascript
浅谈JSON.parse()和JSON.stringify()
2015/07/14 Javascript
jquery实现先淡出再折叠收起的动画效果
2015/08/07 Javascript
详解Angular 4.x NgTemplateOutlet
2017/05/24 Javascript
nodeJS模块简单用法示例
2018/04/21 NodeJs
vue router总结 $router和$route及router与 router与route区别
2019/07/05 Javascript
layui表单提交到后台自动封装到实体类的方法
2019/09/12 Javascript
JavaScript实现Excel表格效果
2020/02/07 Javascript
解决Mint-ui 框架Popup和Datetime Picker组件滚动穿透的问题
2020/11/04 Javascript
vue+iview实现文件上传
2020/11/17 Vue.js
[11:01]2014DOTA2西雅图邀请赛 冷冷带你探秘威斯汀
2014/07/08 DOTA
零基础写python爬虫之urllib2使用指南
2014/11/05 Python
python通过wxPython打开一个音频文件并播放的方法
2015/03/25 Python
django反向解析和正向解析的方式
2018/06/05 Python
Python 变量类型详解
2018/10/10 Python
在IPython中执行Python程序文件的示例
2018/11/01 Python
pytorch查看torch.Tensor和model是否在CUDA上的实例
2020/01/03 Python
英国家居用品和家居装饰品购物网站:Cox & Cox
2019/08/25 全球购物
泰国排名第一的家居用品中心:HomePro
2020/11/18 全球购物
参观监狱心得体会
2014/01/02 职场文书
幼儿园综治宣传月活动总结
2015/05/07 职场文书
篮球比赛通讯稿
2015/07/18 职场文书