tensorflow之自定义神经网络层实例


Posted in Python onFebruary 07, 2020

如下所示:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。

1.关于图层的一些有用操作

许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。

#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
 

#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias

2.自定义图层

实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:

__init__,您可以在其中执行所有与输入无关的初始化

build方法,您知道输入张量的形状,并可以进行其余的初始化

call方法,在这里进行正向传播计算

请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。

class MyDenseLayer(tf.keras.layers.Layer):
 def __init__(self, num_outputs):
  super(MyDenseLayer, self).__init__()
  self.num_outputs = num_outputs
  
 def build(self, input_shape):
  self.kernel = self.add_variable("kernel", 
                  shape=[input_shape[-1].value, 
                      self.num_outputs])
  
 def call(self, input):
  return tf.matmul(input, self.kernel)
 
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)

3.搭建网络结构

机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。

创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。

class ResnetIdentityBlock(tf.keras.Model):
 def __init__(self, kernel_size, filters):
  super(ResnetIdentityBlock, self).__init__(name='')
  filters1, filters2, filters3 = filters
 
  self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
  self.bn2a = tf.keras.layers.BatchNormalization()
 
  self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
  self.bn2b = tf.keras.layers.BatchNormalization()
 
  self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
  self.bn2c = tf.keras.layers.BatchNormalization()
 
 def call(self, input_tensor, training=False):
  x = self.conv2a(input_tensor)
  x = self.bn2a(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2b(x)
  x = self.bn2b(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2c(x)
  x = self.bn2c(x, training=training)
 
  x += input_tensor
  return tf.nn.relu(x)
 
  
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])

以上这篇tensorflow之自定义神经网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 文件重命名工具代码
Jul 26 Python
python读写ini文件示例(python读写文件)
Mar 25 Python
python数据结构之二叉树的遍历实例
Apr 29 Python
python中enumerate函数用法实例分析
May 20 Python
Python通过90行代码搭建一个音乐搜索工具
Jul 29 Python
Python实现获取域名所用服务器的真实IP
Oct 25 Python
Python 专题四 文件基础知识
Mar 20 Python
python中使用iterrows()对dataframe进行遍历的实例
Jun 09 Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 Python
Python 绘制酷炫的三维图步骤详解
Jul 12 Python
python爬不同图片分别保存在不同文件夹中的实现
Apr 02 Python
详细总结Python常见的安全问题
May 21 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 #Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 #Python
python文件和文件夹复制函数
Feb 07 #Python
tf.concat中axis的含义与使用详解
Feb 07 #Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 #Python
Python for循环通过序列索引迭代过程解析
Feb 07 #Python
python中with用法讲解
Feb 07 #Python
You might like
PHP中return 和 exit 、break和contiue 区别与用法
2012/04/09 PHP
跨浏览器的 mouseenter mouseleave 以及 compareDocumentPosition的使用说明
2010/05/04 Javascript
javascript的console.log()用法小结
2012/05/31 Javascript
载入jQuery库的最佳方法详细说明及实现代码
2012/12/28 Javascript
非html5实现js版弹球游戏示例代码
2013/09/22 Javascript
JavaScript中使用Substring删除字符串最后一个字符
2013/11/03 Javascript
JavaScript异步加载浅析
2014/12/28 Javascript
IE7浏览器窗口大小改变事件执行多次bug及IE6/IE7/IE8下resize问题
2015/08/21 Javascript
详解Angularjs中的依赖注入
2016/03/11 Javascript
基于js中的原型、继承的一些想法
2016/08/10 Javascript
Javascript 调用 ActionScript 的简单方法
2016/09/22 Javascript
用jmSlip编写移动端顶部日历选择控件
2016/10/24 Javascript
写jQuery插件时的注意点
2017/02/20 Javascript
老生常谈javascript中逻辑运算符&&和||的返回值问题
2017/04/13 Javascript
js实现音乐播放控制条
2017/09/09 Javascript
Vue实现点击后文字变色切换方法
2018/02/11 Javascript
Vue之mixin全局的用法详解
2018/08/22 Javascript
JavaScript常用内置对象用法分析
2019/07/09 Javascript
JavaScript实现移动端弹窗后禁止滚动
2020/05/25 Javascript
Saltstack快速入门简单汇总
2016/03/01 Python
python语言基本语句用法总结
2019/06/11 Python
使用pytorch实现可视化中间层的结果
2019/12/30 Python
Python操作MongoDb数据库流程详解
2020/03/05 Python
Html5页面中的返回实现的方法
2018/02/26 HTML / CSS
俄语地区最大的中国商品在线购物网站之一:Umka Mall
2019/11/03 全球购物
绝对经典成功的大学生推荐信
2013/11/08 职场文书
外贸专业求职信
2014/03/09 职场文书
关爱留守儿童倡议书
2014/04/15 职场文书
小学六年级学生评语
2014/04/22 职场文书
新闻发布会策划方案
2014/06/12 职场文书
高等教育学专业自荐书
2014/06/17 职场文书
民主生活会整改措施(党员)
2014/09/18 职场文书
2014年小学美术工作总结
2014/12/20 职场文书
青年教师个人总结
2015/02/11 职场文书
超市食品安全承诺书
2015/04/29 职场文书
基于PyTorch实现一个简单的CNN图像分类器
2021/05/29 Python