对tensorflow中tf.nn.conv1d和layers.conv1d的区别详解


Posted in Python onFebruary 11, 2020

在用tensorflow做一维的卷积神经网络的时候会遇到tf.nn.conv1d和layers.conv1d这两个函数,但是这两个函数有什么区别呢,通过计算得到一些规律。

1.关于tf.nn.conv1d的解释,以下是Tensor Flow中关于tf.nn.conv1d的API注解:

Computes a 1-D convolution given 3-D input and filter tensors.

Given an input tensor of shape [batch, in_width, in_channels] if data_format is "NHWC", or [batch, in_channels, in_width] if data_format is "NCHW", and a filter / kernel tensor of shape [filter_width, in_channels, out_channels], this op reshapes the arguments to pass them to conv2d to perform the equivalent convolution operation.

Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`. For example, if `data_format` does not start with "NC", a tensor of shape [batch, in_width, in_channels] is reshaped to [batch, 1, in_width, in_channels], and the filter is reshaped to [1, filter_width, in_channels, out_channels]. The result is then reshaped back to [batch, out_width, out_channels] whereoutwidthisafunctionofthestrideandpaddingasinconv2dwhereoutwidthisafunctionofthestrideandpaddingasinconv2d and returned to the caller.

Args: value: A 3D `Tensor`. Must be of type `float32` or `float64`. filters: A 3D `Tensor`. Must have the same type as `input`. stride: An `integer`. The number of entries by which the filter is moved right at each step. padding: 'SAME' or 'VALID' use_cudnn_on_gpu: An optional `bool`. Defaults to `True`. data_format: An optional `string` from `"NHWC", "NCHW"`. Defaults to `"NHWC"`, the data is stored in the order of [batch, in_width, in_channels]. The `"NCHW"` format stores data as [batch, in_channels, in_width]. name: A name for the operation (optional).

Returns:

A `Tensor`. Has the same type as input.

Raises:

ValueError: if `data_format` is invalid.

什么意思呢?就是说conv1d的参数含义:(以NHWC格式为例,即,通道维在最后)

1、value:在注释中,value的格式为:[batch, in_width, in_channels],batch为样本维,表示多少个样本,in_width为宽度维,表示样本的宽度,in_channels维通道维,表示样本有多少个通道。 事实上,也可以把格式看作如下:[batch, 行数, 列数],把每一个样本看作一个平铺开的二维数组。这样的话可以方便理解。

2、filters:在注释中,filters的格式为:[filter_width, in_channels, out_channels]。按照value的第二种看法,filter_width可以看作每次与value进行卷积的行数,in_channels表示value一共有多少列(与value中的in_channels相对应)。out_channels表示输出通道,可以理解为一共有多少个卷积核,即卷积核的数目。

3、stride:一个整数,表示步长,每次(向下)移动的距离(TensorFlow中解释是向右移动的距离,这里可以看作向下移动的距离)。

4、padding:同conv2d,value是否需要在下方填补0。

5、name:名称。可省略。

首先从参数列表可以看出value指的输入的数据,stride就是卷积的步长,这里我们最有疑问的就是filters这个参数,那么我们对filter进行简单的说明。从上面可以看到filters的格式为:[filter_width, in_channels, out_channels],这是一个数组的维度,对应的是卷积核的大小,输入的channel的格式,和卷积核的个数,下面我们用例子说明问题:

import tensorflow as tf
import numpy as np
 
 
if __name__ == '__main__':
  inputs = tf.constant(np.arange(1, 6, dtype=np.float32), shape=[1, 5, 1])
  w = np.array([1, 2], dtype=np.float32).reshape([2, 1, 1])
  # filter width, filter channels and out channels(number of kernels)
  cov1 = tf.nn.conv1d(inputs, w, stride=1, padding='VALID')
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(cov1)
    print(out)

其输出为:

[[[ 5.],
    [ 8.],
    [11.],
    [14.]]]

我们分析一下,输入的数据为[[[1],[2],[3],[4],[5]]],有5个特征,分别对应的数值为1,2,3,4,5,那么经过卷积的结果为5,8,11,14,那么这个结果是怎么来的呢,我们根据卷积的计算,可以得到5 = 1*1 + 2*2, 8=2*1+ 3*2, 11 = 3*1+4*2, 14=4*1+5*2, 也就是W1=1, W2=2,正好和我们先面filters设置的数值相等,

w = np.array([1, 2], dtype=np.float32).reshape([2, 1, 1])

所以可以看到这个filtes设置的是是卷积核矩阵的,换句话说,卷积核矩阵我们是可以设置的。

2. 1.关于tf.layers.conv1d,函数的定义如下

tf.layers.conv1d(
 
inputs,
 
filters,
 
kernel_size,
 
strides=1,
 
padding='valid',
 
data_format='channels_last',
 
dilation_rate=1,
 
activation=None,
 
use_bias=True,
 
kernel_initializer=None,
 
bias_initializer=tf.zeros_initializer(),
 
kernel_regularizer=None,
 
bias_regularizer=None,
 
activity_regularizer=None,
 
kernel_constraint=None,
 
bias_constraint=None,
 
trainable=True,
 
name=None,
 
reuse=None
 
)

比较重要的几个参数是inputs, filters, kernel_size,下面分别说明

inputs : 输入tensor, 维度(None, a, b) 是一个三维的tensor

None : 一般是填充样本的个数,batch_size

a : 句子中的词数或者字数

b : 字或者词的向量维度

filters : 过滤器的个数

kernel_size : 卷积核的大小,卷积核其实应该是一个二维的,这里只需要指定一维,是因为卷积核的第二维与输入的词向量维度是一致的,因为对于句子而言,卷积的移动方向只能是沿着词的方向,即只能在列维度移动。一个例子:

import tensorflow as tf
import numpy as np
 
 
if __name__ == '__main__':
  inputs = tf.constant(np.arange(1, 6, dtype=np.float32), shape=[1, 5, 1])
  cov2 = tf.layers.conv1d(inputs, filters=1, kernel_size=2, strides=1, padding='VALID')
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(cov2)
    print(out)

输出结果:

[[[-1.9953331]
 [-3.5520997]
 [-5.108866 ]
 [-6.6656327]]]

也许你得到的结果和我得到的结果不同,因为在这个函数里面只是设置了卷积核的尺寸和步长,没有设置具体的卷积核矩阵,所以这个卷积核矩阵是随机生成的,就会出现可能运行上面的程序出现不同结果的情况。

以上这篇对tensorflow中tf.nn.conv1d和layers.conv1d的区别详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python paramiko实现ssh远程访问的方法
Dec 03 Python
利用Python脚本在Nginx和uwsgi上部署MoinMoin的教程
May 05 Python
python绘制简单折线图代码示例
Dec 19 Python
python库lxml在linux和WIN系统下的安装
Jun 24 Python
Python Tkinter模块实现时钟功能应用示例
Jul 23 Python
Python3爬虫学习之将爬取的信息保存到本地的方法详解
Dec 12 Python
python二维码操作:对QRCode和MyQR入门详解
Jun 24 Python
Python OpenCV视频截取并保存实现代码
Nov 30 Python
解决Pytorch 加载训练好的模型 遇到的error问题
Jan 10 Python
keras的ImageDataGenerator和flow()的用法说明
Jul 03 Python
如何在scrapy中捕获并处理各种异常
Sep 28 Python
Python如何telnet到网络设备
Feb 18 Python
python中文分词库jieba使用方法详解
Feb 11 #Python
Transpose 数组行列转置的限制方式
Feb 11 #Python
Tensorflow:转置函数 transpose的使用详解
Feb 11 #Python
tensorflow多维张量计算实例
Feb 11 #Python
python误差棒图errorbar()函数实例解析
Feb 11 #Python
解决Python3.8用pip安装turtle-0.0.2出现错误问题
Feb 11 #Python
python scatter函数用法实例详解
Feb 11 #Python
You might like
浅析PHP的ASCII码转换类
2013/07/05 PHP
浅析iis7.5安装配置php环境
2015/05/10 PHP
PHP创建PowerPoint2007文档的方法
2015/12/10 PHP
给PHP开发者的编程指南 第一部分降低复杂程度
2016/01/18 PHP
PHP实现双链表删除与插入节点的方法示例
2017/11/11 PHP
PHP之认识(二)关于Traits的用法详解
2019/04/11 PHP
fromCharCode和charCodeAt 方法
2006/12/27 Javascript
Javascript中的数学函数集合
2007/05/08 Javascript
在IE,Firefox,Safari,Chrome,Opera浏览器上调试javascript
2008/12/02 Javascript
用JQuery 实现AJAX加载XML并解析的脚本
2009/07/25 Javascript
通过DOM脚本去设置样式信息
2010/09/19 Javascript
让AJAX不依赖后端接口实现方案
2012/12/03 Javascript
js字符串转成JSON
2013/11/07 Javascript
JavaScript父子窗体间的调用方法
2015/03/31 Javascript
JavaScript使用二分查找算法在数组中查找数据的方法
2015/04/07 Javascript
js实现鼠标点击文本框自动选中内容的方法
2015/08/20 Javascript
JS实现的在线调色板实例(附demo源码下载)
2016/03/01 Javascript
js原型链与继承解析(初体验)
2016/05/09 Javascript
JavaScript在form表单中使用button按钮实现submit提交方法
2017/01/23 Javascript
Next.js项目实战踩坑指南(笔记)
2018/11/29 Javascript
angular6的table组件开发的实现示例
2018/12/26 Javascript
vue-cli脚手架引入弹出层layer插件的几种方法
2019/06/24 Javascript
element表格翻页第2页从1开始编号(后端从0开始分页)
2019/12/10 Javascript
Vue elementui字体图标显示问题解决方案
2020/08/18 Javascript
Echarts.js无法引入问题解决方案
2020/10/30 Javascript
[54:30]Liquid vs Newbee 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/16 DOTA
Python实现股市信息下载的方法
2015/06/15 Python
Python的collections模块中namedtuple结构使用示例
2016/07/07 Python
GitHub 热门:Python 算法大全,Star 超过 2 万
2019/04/29 Python
django自带调试服务器的使用详解
2019/08/29 Python
python如何查看网页代码
2020/06/07 Python
CSS3+DIV实现漂亮的动画彩色标签
2016/06/16 HTML / CSS
Madewell美德威尔美国官网:美国休闲服饰品牌
2016/11/25 全球购物
厂长助理岗位职责
2013/12/27 职场文书
行政求职信
2014/07/04 职场文书
2015年银行大堂经理工作总结
2015/04/24 职场文书