用tensorflow实现弹性网络回归算法


Posted in Python onJanuary 09, 2018

本文实例为大家分享了tensorflow实现弹性网络回归算法,供大家参考,具体内容如下

python代码:

#用tensorflow实现弹性网络算法(多变量) 
#使用鸢尾花数据集,后三个特征作为特征,用来预测第一个特征。 
 
 
#1 导入必要的编程库,创建计算图,加载数据集 
import matplotlib.pyplot as plt 
import tensorflow as tf 
import numpy as np 
from sklearn import datasets 
from tensorflow.python.framework import ops 
 
ops.get_default_graph() 
sess = tf.Session() 
iris = datasets.load_iris() 
 
x_vals = np.array([[x[1], x[2], x[3]] for x in iris.data]) 
y_vals = np.array([y[0] for y in iris.data]) 
 
 
#2 声明学习率,批量大小,占位符和模型变量,模型输出 
learning_rate = 0.001 
batch_size = 50 
x_data = tf.placeholder(shape=[None, 3], dtype=tf.float32) #占位符大小为3 
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32) 
A = tf.Variable(tf.random_normal(shape=[3,1])) 
b = tf.Variable(tf.random_normal(shape=[1,1])) 
model_output = tf.add(tf.matmul(x_data, A), b) 
 
 
#3 对于弹性网络回归算法,损失函数包括L1正则和L2正则 
elastic_param1 = tf.constant(1.) 
elastic_param2 = tf.constant(1.) 
l1_a_loss = tf.reduce_mean(abs(A)) 
l2_a_loss = tf.reduce_mean(tf.square(A)) 
e1_term = tf.multiply(elastic_param1, l1_a_loss) 
e2_term = tf.multiply(elastic_param2, l2_a_loss) 
loss = tf.expand_dims(tf.add(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), e1_term), e2_term), 0) 
 
 
 
#4 初始化变量, 声明优化器, 然后遍历迭代运行, 训练拟合得到参数 
init = tf.global_variables_initializer() 
sess.run(init) 
my_opt = tf.train.GradientDescentOptimizer(learning_rate) 
train_step = my_opt.minimize(loss) 
 
loss_vec = [] 
for i in range(1000): 
   rand_index = np.random.choice(len(x_vals), size=batch_size) 
   rand_x = x_vals[rand_index] 
   rand_y = np.transpose([y_vals[rand_index]]) 
   sess.run(train_step, feed_dict={x_data:rand_x, y_target:rand_y}) 
   temp_loss = sess.run(loss, feed_dict={x_data:rand_x, y_target:rand_y}) 
   loss_vec.append(temp_loss) 
   if (i+1)%250 == 0: 
     print('Step#' + str(i+1) +'A = ' + str(sess.run(A)) + 'b=' + str(sess.run(b))) 
     print('Loss= ' +str(temp_loss)) 
      
 
#现在能观察到, 随着训练迭代后损失函数已收敛。 
plt.plot(loss_vec, 'k--') 
plt.title('Loss per Generation') 
plt.xlabel('Generation') 
plt.ylabel('Loss') 
plt.show()

本文参考书《Tensorflow机器学习实战指南》

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现RSA加密(解密)算法
Feb 17 Python
Python中用字符串调用函数或方法示例代码
Aug 04 Python
Python编程之基于概率论的分类方法:朴素贝叶斯
Nov 11 Python
Python金融数据可视化汇总
Nov 17 Python
python使用多进程的实例详解
Sep 19 Python
python调用c++ ctype list传数组或者返回数组的方法
Feb 13 Python
Python实现Selenium自动化Page模式
Jul 14 Python
python如何删除文件中重复的字段
Jul 16 Python
pytorch梯度剪裁方式
Feb 04 Python
Python sklearn库实现PCA教程(以鸢尾花分类为例)
Feb 24 Python
Python通过文本和图片生成词云图
May 21 Python
selenium+python实现基本自动化测试的示例代码
Jan 27 Python
Python+matplotlib实现计算两个信号的交叉谱密度实例
Jan 08 #Python
python matplotlib 注释文本箭头简单代码示例
Jan 08 #Python
Python自定义简单图轴简单实例
Jan 08 #Python
[原创]python爬虫(入门教程、视频教程)
Jan 08 #Python
小米5s微信跳一跳小程序python源码
Jan 08 #Python
Python实现判断字符串中包含某个字符的判断函数示例
Jan 08 #Python
Python实现的字典值比较功能示例
Jan 08 #Python
You might like
德生S2000电路分析
2021/03/02 无线电
PHP中CURL方法curl_setopt()函数的参数分享
2013/01/19 PHP
php下获取http状态的实现代码
2014/05/09 PHP
PHP微框架Dispatch简介
2014/06/12 PHP
php+MySql实现登录系统与输出浏览者信息功能
2016/07/01 PHP
表单的焦点顺序tabindex和对应enter键提交
2013/01/04 Javascript
JavaScript中扩展Array contains方法实例
2020/08/23 Javascript
Node.js开发者必须了解的4个JS要点
2016/02/21 Javascript
js滑动提示效果代码分享
2016/03/10 Javascript
js贪吃蛇游戏实现思路和源码
2016/04/14 Javascript
浅谈js中的延迟执行和定时执行
2016/05/31 Javascript
jQuery导航条固定定位效果实例代码
2017/05/26 jQuery
js 获取元素的具体样式信息getcss(实例讲解)
2017/07/05 Javascript
详解js几个绕不开的事件兼容写法
2017/08/30 Javascript
基于 Immutable.js 实现撤销重做功能的实例代码
2018/03/01 Javascript
vue如何在自定义组件中使用v-model
2018/05/14 Javascript
bootstrap table插件动态加载表头
2019/07/19 Javascript
js实现图片实时时钟
2020/01/15 Javascript
web.py获取上传文件名的正确方法
2014/08/26 Python
在ironpython中利用装饰器执行SQL操作的例子
2015/05/02 Python
Python实现聊天机器人的示例代码
2018/07/09 Python
python提取包含关键字的整行数据方法
2018/12/11 Python
Python使用贪婪算法解决问题
2019/10/22 Python
使用python 的matplotlib 画轨道实例
2020/01/19 Python
Django后端分离 使用element-ui文件上传方式
2020/07/12 Python
Python如何合并多个字典或映射
2020/07/24 Python
Python实现文件压缩和解压的示例代码
2020/08/12 Python
澳大利亚电商Catch新西兰站:Catch.co.nz
2020/05/30 全球购物
大学生职业生涯规划方案
2014/01/03 职场文书
解除劳动合同证明书
2014/09/26 职场文书
导游词开场白
2015/01/31 职场文书
财务部岗位职责
2015/02/03 职场文书
情况说明书怎么写
2015/10/08 职场文书
为什么mysql字段要使用NOT NULL
2021/05/13 MySQL
Python中22个万用公式的小结
2021/07/21 Python
MySQL分布式恢复进阶
2022/07/23 MySQL