在tensorflow实现直接读取网络的参数(weight and bias)的值


Posted in Python onJune 24, 2020

训练好了一个网络,想要查看网络里面参数是否经过BP算法优化过,可以直接读取网络里面的参数,如果一直是随机初始化的值,则证明训练代码有问题,需要改。

下面介绍如何直接读取网络的weight 和 bias。

(1) 获取参数的变量名。可以使用一下函数获取变量名:

def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]

输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。

(2) 利用session读取变量的值:

def get_weight(self):
 full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) ##一定要先初始化变量
  print(sess.run(full_connect_variable[0]))

之后如果想要看参数随着训练的变化,你可以将这些参数保存到一个txt文件里面查看。

补充知识:如何在 PyTorch 中设定学习率衰减(learning rate decay)

在tensorflow实现直接读取网络的参数(weight and bias)的值

很多时候我们要对学习率(learning rate)进行衰减,下面的代码示范了如何每30个epoch按10%的速率衰减:

def adjust_learning_rate(optimizer, epoch):
 """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
 lr = args.lr * (0.1 ** (epoch // 30))
 for param_group in optimizer.param_groups:
  param_group['lr'] = lr

什么是param_groups?

optimizer通过param_group来管理参数组.param_group中保存了参数组及其对应的学习率,动量等等.所以我们可以通过更改param_group[‘lr']的值来更改对应参数组的学习率。

# 有两个`param_group`即,len(optim.param_groups)==2
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
   ], lr=1e-2, momentum=0.9)
 
#一个参数组
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

以上这篇在tensorflow实现直接读取网络的参数(weight and bias)的值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现web方式logview的方法
Aug 10 Python
基于Python 的进程管理工具supervisor使用指南
Sep 18 Python
Python正则表达式完全指南
May 25 Python
详解TensorFlow在windows上安装与简单示例
Mar 05 Python
python实现文本界面网络聊天室
Dec 12 Python
python如何实现视频转代码视频
Jun 17 Python
python 进程间数据共享multiProcess.Manger实现解析
Sep 23 Python
python使用ctypes调用扩展模块的实例方法
Jan 28 Python
keras实现多种分类网络的方式
Jun 11 Python
利用Python实现Json序列化库的方法步骤
Sep 09 Python
PyCharm 2020.1版安装破解注册码永久激活(激活到2089年)
Sep 24 Python
python线程池 ThreadPoolExecutor 的用法示例
Oct 10 Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
pytorch中的weight-initilzation用法
Jun 24 #Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
You might like
php 判断网页是否是utf8编码的方法
2014/06/06 PHP
php批量删除cookie的简单实现方法
2015/01/26 PHP
动手学习无线电
2021/03/10 无线电
jQuery EasyUI API 中文文档 - EasyLoader 加载器
2011/09/29 Javascript
uploadify在Firefox下丢失session问题的解决方法
2013/08/07 Javascript
理解JavaScript原型链
2016/10/25 Javascript
jQuery获取Table某列的值(推荐)
2017/03/03 Javascript
用element的upload组件实现多图片上传和压缩的示例代码
2019/02/12 Javascript
五分钟搞懂Vuex实用知识(小结)
2019/08/12 Javascript
小程序实现多个选项卡切换
2020/06/19 Javascript
[15:28]DOTA2 HEROS教学视频教你分分钟做大人-剧毒术士
2014/06/13 DOTA
[01:12]DOTA2次级职业联赛 - Newbee.Y 战队宣传片
2014/12/01 DOTA
[03:02]2020完美世界城市挑战赛(秋季赛)总决赛回顾
2021/03/11 DOTA
Python 网页解析HTMLParse的实例详解
2017/08/10 Python
python实现逆序输出一个数字的示例讲解
2018/06/25 Python
Python WEB应用部署的实现方法
2019/01/02 Python
python实现多层感知器
2019/01/18 Python
浅谈Python批处理文件夹中的txt文件
2019/03/11 Python
Python中的self用法详解
2019/08/06 Python
python用类实现文章敏感词的过滤方法示例
2019/10/27 Python
python zip()函数使用方法解析
2019/10/31 Python
python实现把两个二维array叠加成三维array示例
2019/11/29 Python
Python-openCV读RGB通道图实例
2020/01/17 Python
对Tensorflow中Device实例的生成和管理详解
2020/02/04 Python
pandas数据处理之绘图的实现
2020/06/15 Python
python 读取、写入txt文件的示例
2020/09/27 Python
python wsgiref源码解析
2021/02/06 Python
Django和Ueditor自定义存储上传文件的文件名
2021/02/25 Python
CSS3 border-image详解、应用及jQuery插件
2011/08/29 HTML / CSS
详解HTML5通讯录获取指定多个人的信息
2016/12/20 HTML / CSS
Speedo速比涛德国官方网站:世界领先的泳装品牌
2019/08/26 全球购物
酒吧员工的岗位职责
2013/11/26 职场文书
个人充满哲理的自我评价
2014/02/20 职场文书
《他得的红圈圈最多》教学反思
2014/04/24 职场文书
就业证明函
2015/06/17 职场文书
原生JS实现分页
2022/04/19 Javascript