Keras实现支持masking的Flatten层代码


Posted in Python onJune 16, 2020

不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。

Keras原本Flatten的实现

class Flatten(Layer):
 def __init__(self, **kwargs):
  super(Flatten, self).__init__(**kwargs)
  self.input_spec = InputSpec(min_ndim=3)

 def compute_output_shape(self, input_shape):
  if not all(input_shape[1:]):
   raise ValueError('The shape of the input to "Flatten" '
        'is not fully defined '
        '(got ' + str(input_shape[1:]) + '. '
        'Make sure to pass a complete "input_shape" '
        'or "batch_input_shape" argument to the first '
        'layer in your model.')
  return (input_shape[0], np.prod(input_shape[1:]))

 def call(self, inputs):
  return K.batch_flatten(inputs)

自定义支持masking的实现

事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。

from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
import numpy as np

class MyFlatten(Layer):
 def __init__(self, **kwargs):
  self.supports_masking = True
  super(MyFlatten, self).__init__(**kwargs)

 def compute_mask(self, inputs, mask=None):
  if mask==None:
   return mask
  return K.batch_flatten(mask)

 def call(self, inputs, mask=None):
  return K.batch_flatten(inputs)

 def compute_output_shape(self, input_shape):
  return (input_shape[0], np.prod(input_shape[1:]))

正确性检验

from keras.layers import *
from keras.models import Model
from MyFlatten import MyFlatten
from MySumLayer import MySumLayer
from keras.initializers import ones

data = [[1,0,0,0],
  [1,2,0,0],
  [1,2,3,0],
  [1,2,3,4]]

A = Input(shape=[4]) # None * 4
emb = Embedding(5, 3, mask_zero=True, embeddings_initializer=ones())(A) # None * 4 * 3
fla = MyFlatten()(emb) # None * 12
out = MySumLayer(axis=1)(fla) # None * 1

model = Model(inputs=[A], outputs=[out])
print model.predict(data)

输出:

[ 3. 6. 9. 12.]

补充知识:pytorch中的reshape()、view()、transpose()和flatten()

1、torch.reshape()

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用

其作用是在不改变tensor元素数目的情况下改变tensor的shape

import torch
import numpy as np
a = np.arange(24)
b = a.reshape(4,3,2)
print(np.shape(a))
print(b,np.shape(b))

'''结果
(24,)
[[[ 0 1]
 [ 2 3]
 [ 4 5]]

 [[ 6 7]
 [ 8 9]
 [10 11]]

 [[12 13]
 [14 15]
 [16 17]]

 [[18 19]
 [20 21]
 [22 23]]] (4, 3, 2)
'''

2、view()

view()只可以由torch.Tensor.view()来调用

view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。

3、transpose()

torch.transpose(input, dim0, dim1) -> Tensor

将输入数据input的第dim0维和dim1维进行交换

#官方例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.9068, 1.8803, -0.5021],
  [-0.6576, 0.6334, -0.8961]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.9068, -0.6576],
  [ 1.8803, 0.6334],
  [-0.5021, -0.8961]])

4、flatten()

torch.flatten()的输入是tensor

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,

#官方例子
>>> t = torch.tensor([[[1, 2],
        [3, 4]],
        [[5, 6],
        [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
  [5, 6, 7, 8]])

torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,

>>> m = torch.nn.Sequential(
 torch.nn.Conv2d(1, 32, 5, 1, 1),
 torch.nn.Flatten(),
 torch.nn.Linear(160,10))
>>> m
Sequential(
 (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
 (1): Flatten()
 (2): Linear(in_features=160, out_features=10, bias=True)
)

以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中下划线的使用方法
Mar 27 Python
Python编程入门的一些基本知识
May 13 Python
Python3指定路径寻找符合匹配模式文件
May 22 Python
使用Python的Bottle框架写一个简单的服务接口的示例
Aug 25 Python
Python数据持久化shelve模块用法分析
Jun 29 Python
Django框架 querySet功能解析
Sep 04 Python
Python3内置函数chr和ord实现进制转换
Jun 05 Python
Keras—embedding嵌入层的用法详解
Jun 10 Python
python 实现朴素贝叶斯算法的示例
Sep 30 Python
python绘制汉诺塔
Mar 01 Python
python实现会员信息管理系统(List)
Mar 18 Python
解决IDEA翻译插件Translation报错更新TTK失败不能使用
Apr 24 Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
You might like
一个目录遍历函数
2006/10/09 PHP
PHP中上传多个文件的表单设计例子
2014/11/19 PHP
php+ajax实现无刷新数据分页的办法
2015/11/02 PHP
微信公众号开发之文本消息自动回复php代码
2016/08/08 PHP
thinkphp中U方法按路由规则生成url的方法
2018/03/12 PHP
PHP Beanstalkd消息队列的安装与使用方法实例详解
2020/02/21 PHP
asp 取文本框名称代码
2008/12/02 Javascript
用jQuery扩展自写的 UI导航
2010/01/13 Javascript
jquery鼠标停止移动事件
2013/12/21 Javascript
javascript实现验证身份证号的有效性并提示
2015/04/30 Javascript
JS控制按钮10秒钟后可用的方法
2015/12/22 Javascript
全面解析JavaScript中“&&”和“||”操作符(总结篇)
2016/07/18 Javascript
jQuery插件Easyui设置datagrid的pageNumber导致两次请求问题的解决方法
2016/08/06 Javascript
JS IOS/iPhone的Safari浏览器不兼容Javascript中的Date()问题如何解决
2016/11/11 Javascript
从零开始学习Node.js系列教程五:服务器监听方法示例
2017/04/13 Javascript
vue slot 在子组件中显示父组件传递的模板
2018/03/02 Javascript
微信web端后退强制刷新功能的实现代码
2018/03/04 Javascript
react router4+redux实现路由权限控制的方法
2018/05/03 Javascript
解决node修改后需频繁手动重启的问题
2018/05/13 Javascript
对Vue.js之事件的绑定(v-on: 或者 @ )详解
2018/09/15 Javascript
uniapp电商小程序实现订单30分钟倒计时
2020/11/01 Javascript
用Python写冒泡排序代码
2016/04/12 Python
Python高级特性切片(Slice)操作详解
2018/09/27 Python
python使用递归的方式建立二叉树
2019/07/03 Python
多版本python的pip 升级后, pip2 pip3 与python版本失配解决方法
2019/09/11 Python
python通过SSH登陆linux并操作的实现
2019/10/10 Python
Python基于requests库爬取网站信息
2020/03/02 Python
详解python环境安装selenium和手动下载安装selenium的方法
2020/03/17 Python
怎么写好自荐信
2013/10/30 职场文书
如何写自我鉴定
2014/03/19 职场文书
学习方法演讲稿
2014/05/10 职场文书
企业宣传策划方案
2014/05/29 职场文书
植树节标语
2014/06/27 职场文书
迎新春趣味活动方案
2014/08/24 职场文书
运动会广播稿200字
2014/10/18 职场文书
PHP中多字节字符串操作实例详解
2021/08/23 PHP