关于tensorflow的几种参数初始化方法小结


Posted in Python onJanuary 04, 2020

在tensorflow中,经常会遇到参数初始化问题,比如在训练自己的词向量时,需要对原始的embeddigs矩阵进行初始化,更一般的,在全连接神经网络中,每层的权值w也需要进行初始化。

tensorlfow中应该有一下几种初始化方法

1. tf.constant_initializer() 常数初始化
2. tf.ones_initializer() 全1初始化
3. tf.zeros_initializer() 全0初始化
4. tf.random_uniform_initializer() 均匀分布初始化
5. tf.random_normal_initializer() 正态分布初始化
6. tf.truncated_normal_initializer() 截断正态分布初始化
7. tf.uniform_unit_scaling_initializer() 这种方法输入方差是常数
8. tf.variance_scaling_initializer() 自适应初始化
9. tf.orthogonal_initializer() 生成正交矩阵

具体的

1、tf.constant_initializer(),它的简写是tf.Constant()

#coding:utf-8
import numpy as np 
import tensorflow as tf 
train_inputs = [[1,2],[1,4],[3,2]]
with tf.variable_scope("embedding-layer"):
  val = np.array([[1,2,3,4,5,6,7],[1,3,4,5,2,1,9],[0,12,3,4,5,7,8],[2,3,5,5,6,8,9],[3,1,6,1,2,3,5]])
  const_init = tf.constant_initializer(val)
  embeddings = tf.get_variable("embed",shape=[5,7],dtype=tf.float32,initializer=const_init)
  embed = tf.nn.embedding_lookup(embeddings, train_inputs)             #在embedding中查找train_input所对应的表示
  print("embed",embed)
  sum_embed = tf.reduce_mean(embed,1)
initall = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(initall)
  print(sess.run(embed))
  print(sess.run(tf.shape(embed)))
  print(sess.run(sum_embed))

4、random_uniform_initializer = RandomUniform()

可简写为tf.RandomUniform()

生成均匀分布的随机数,参数有四个(minval=0, maxval=None, seed=None, dtype=dtypes.float32),分别用于指定最小值,最大值,随机数种子和类型。

6、tf.truncated_normal_initializer()

可简写tf.TruncatedNormal()

生成截断正态分布的随机数,这个初始化方法在tf中用得比较多。

它有四个参数(mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32),分别用于指定均值、标准差、随机数种子和随机数的数据类型,一般只需要设置stddev这一个参数就可以了。

8、tf.variance_scaling_initializer()

可简写为tf.VarianceScaling()

参数为(scale=1.0,mode="fan_in",distribution="normal",seed=None,dtype=dtypes.float32)

scale: 缩放尺度(正浮点数)

mode: "fan_in", "fan_out", "fan_avg"中的一个,用于计算标准差stddev的值。

distribution:分布类型,"normal"或“uniform"中的一个。

当 distribution="normal" 的时候,生成truncated normal distribution(截断正态分布) 的随机数,其中stddev = sqrt(scale / n) ,n的计算与mode参数有关。

如果mode = "fan_in", n为输入单元的结点数;

如果mode = "fan_out",n为输出单元的结点数;

如果mode = "fan_avg",n为输入和输出单元结点数的平均值。

当distribution="uniform”的时候 ,生成均匀分布的随机数,假设分布区间为[-limit, limit],则 limit = sqrt(3 * scale / n)

以上这篇关于tensorflow的几种参数初始化方法小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python文件和目录操作方法大全(含实例)
Mar 12 Python
python使用urllib2提交http post请求的方法
May 26 Python
Python内置函数OCT详解
Nov 09 Python
Pycharm在创建py文件时,自动添加文件头注释的实例
May 07 Python
Python简单爬虫导出CSV文件的实例讲解
Jul 06 Python
Sanic框架应用部署方法详解
Jul 18 Python
python获取磁盘号下盘符步骤详解
Jun 19 Python
Python-numpy实现灰度图像的分块和合并方式
Jan 09 Python
python中with用法讲解
Feb 07 Python
Python ArgumentParse的subparser用法说明
Apr 20 Python
Python实现密钥密码(加解密)实例详解
Apr 26 Python
Python 中如何写注释
Aug 28 Python
基于TensorFlow常量、序列以及随机值生成实例
Jan 04 #Python
Tensorflow 实现分批量读取数据
Jan 04 #Python
Tensorflow的常用矩阵生成方式
Jan 04 #Python
Tensorflow读取并输出已保存模型的权重数值方式
Jan 04 #Python
tensorflow实现打印ckpt模型保存下的变量名称及变量值
Jan 04 #Python
tensorflow 获取所有variable或tensor的name示例
Jan 04 #Python
tensorflow没有output结点,存储成pb文件的例子
Jan 04 #Python
You might like
基于mysql的论坛(2)
2006/10/09 PHP
PHP一些有意思的小区别
2006/12/06 PHP
php遍历数组的方法分享
2012/03/22 PHP
单台服务器的PHP进程之间实现共享内存的方法
2014/06/13 PHP
YII框架模块化处理操作示例
2019/04/26 PHP
分享几种好用的PHP自定义加密函数(可逆/不可逆)
2020/09/15 PHP
语义化 H1 标签
2008/01/14 Javascript
js 无提示关闭浏览器页面的代码
2010/03/09 Javascript
Jquery升级新版本后选择器的语法问题
2010/06/02 Javascript
浅析js预加载/延迟加载
2014/09/25 Javascript
实例讲解jquery与json的结合
2016/01/07 Javascript
浅谈Javascript中的Label语句
2016/12/14 Javascript
利用imgareaselect辅助后台实现图片上传裁剪
2017/03/02 Javascript
JS中判断字符串存在和非空的方法
2018/09/12 Javascript
详解vue配置后台接口方式
2019/03/29 Javascript
angular 表单验证器验证的同时限制输入的实现
2019/04/11 Javascript
Python数组条件过滤filter函数使用示例
2014/07/22 Python
python实现JAVA源代码从ANSI到UTF-8的批量转换方法
2015/08/10 Python
使用python编写简单的小程序编译成exe跑在win10上
2018/01/15 Python
解决pandas使用read_csv()读取文件遇到的问题
2018/06/15 Python
Python多线程编程之多线程加锁操作示例
2018/09/06 Python
Python玩转加密的技巧【推荐】
2019/05/13 Python
Python3的unicode编码转换成中文的问题及解决方案
2019/12/10 Python
最新PyCharm 2020.2.3永久激活码(亲测有效)
2020/11/26 Python
Hudson Jeans官网:高级精制牛仔裤
2018/11/28 全球购物
美国地毯购买网站:Rugs USA
2019/02/23 全球购物
shell程序中如何注释
2012/01/28 面试题
农业局学习党的群众路线教育实践活动心得体会
2014/03/07 职场文书
交警个人先进事迹材料
2014/05/11 职场文书
2014年安全生产责任书
2014/07/22 职场文书
我的中国梦演讲稿800字
2014/08/19 职场文书
小学生读书笔记
2015/07/01 职场文书
pytorch实现手写数字图片识别
2021/05/20 Python
Python合并多张图片成PDF
2021/06/09 Python
Nginx部署vue项目和配置代理的问题解析
2021/08/04 Servers
详解使用内网穿透工具Ngrok代理本地服务
2022/03/31 Servers