Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之编写类之三子类
Oct 11 Python
Python读写Excel文件方法介绍
Nov 22 Python
Python使用SQLite和Excel操作进行数据分析
Jan 20 Python
对python操作kafka写入json数据的简单demo分享
Dec 27 Python
Python获取网段内ping通IP的方法
Jan 31 Python
Python3内置模块之json编解码方法小结【推荐】
Dec 09 Python
python读取.mat文件的数据及实例代码
Jul 12 Python
基于python实现微信好友数据分析(简单)
Feb 16 Python
Pytest mark使用实例及原理解析
Feb 22 Python
Python打包工具PyInstaller的安装与pycharm配置支持PyInstaller详细方法
Feb 27 Python
使用sublime text3搭建Python编辑环境的实现
Jan 12 Python
Python中requests库的用法详解
Jun 05 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
Discuz Uchome ajaxpost小技巧
2011/01/04 PHP
简单谈谈PHP vs Node.js
2015/07/17 PHP
PHP下载远程图片并保存到本地方法总结
2016/01/22 PHP
PHP对象相关知识总结
2017/04/09 PHP
PHP单文件上传原理及上传函数的封装操作示例
2019/09/02 PHP
基于jQuery插件实现点击小图显示大图效果
2016/05/11 Javascript
JavaScript知识点总结(十六)之Javascript闭包(Closure)代码详解
2016/05/31 Javascript
node+experss实现爬取电影天堂爬虫
2016/11/20 Javascript
80%应聘者都不及格的JS面试题
2017/03/21 Javascript
在一般处理程序(ashx)中弹出js提示语
2017/08/16 Javascript
input 标签实现输入框带提示文字效果(两种方法)
2017/10/09 Javascript
详解自定义ajax支持跨域组件封装
2018/02/08 Javascript
深入解析koa之中间件流程控制
2019/06/17 Javascript
JS实现返回上一页并刷新页面的方法分析
2019/07/16 Javascript
微信小程序实现收货地址左滑删除
2020/11/18 Javascript
微信小程序实现点击图片放大预览
2019/10/21 Javascript
VueCli4项目配置反向代理proxy的方法步骤
2020/05/17 Javascript
关于angular引入ng-zorro的问题浅析
2020/09/09 Javascript
[02:06]2018完美世界全国高校联赛秋季赛开始报名(附彩蛋)
2018/09/03 DOTA
ubuntu系统下 python链接mysql数据库的方法
2017/01/09 Python
python实现画圆功能
2018/01/25 Python
python3+PyQt5实现自定义分数滑块部件
2018/04/24 Python
正确理解Python中if __name__ == '__main__'
2019/01/24 Python
Python 如何优雅的将数字转化为时间格式的方法
2019/09/26 Python
Python opencv相机标定实现原理及步骤详解
2020/04/09 Python
Python Unittest原理及基本使用方法
2020/11/06 Python
详解Python中string模块除去Str还剩下什么
2020/11/30 Python
利于python脚本编写可视化nmap和masscan的方法
2020/12/29 Python
python re模块常见用法例举
2021/03/01 Python
街头时尚在线:JESSICABUURMAN
2019/06/16 全球购物
YSL圣罗兰美妆英国官网:Yves Saint Laurent Beauty UK
2019/08/03 全球购物
应用化学专业职业生涯规划书
2013/12/31 职场文书
认购协议书范本
2014/04/22 职场文书
党员创先争优心得体会
2014/09/11 职场文书
优秀员工事迹材料
2014/12/20 职场文书
pycharm2021激活码使用教程(永久激活亲测可用)
2021/03/30 Python