使用Keras实现简单线性回归模型操作


Posted in Python onJune 12, 2020

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

使用Keras实现简单线性回归模型操作

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

使用Keras实现简单线性回归模型操作

二、完整代码

import numpy as np
np.random.seed(1337) 
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
 
# 生成数据
X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
np.random.shuffle(X) # 打乱顺序
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
# plot
plt.scatter(X, Y)
plt.show()
 
X_train, Y_train = X[:160], Y[:160]  # 前160组数据为训练数据集
X_test, Y_test = X[160:], Y[160:]  #后40组数据为测试数据集
 
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=1, units=1))
 
# 选定loss函数和优化器
model.compile(loss='mse', optimizer='sgd')
 
# 训练过程
print('Training -----------')
for step in range(501):
 cost = model.train_on_batch(X_train, Y_train)
 if step % 50 == 0:
  print("After %d trainings, the cost: %f" % (step, cost))
 
# 测试过程
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b)
 
# 将训练结果绘出
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

以上这篇使用Keras实现简单线性回归模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的字符串操作和编码Unicode详解
Jan 18 Python
Python 的类、继承和多态详解
Jul 16 Python
pandas 使用apply同时处理两列数据的方法
Apr 20 Python
Python图像处理之识别图像中的文字(实例讲解)
May 10 Python
Selenium(Python web测试工具)基本用法详解
Aug 10 Python
Python机器学习算法库scikit-learn学习之决策树实现方法详解
Jul 04 Python
python解析yaml文件过程详解
Aug 30 Python
tensorflow中tf.reduce_mean函数的使用
Apr 19 Python
opencv 查找连通区域 最大面积实例
Jun 04 Python
DataFrame 数据合并实现(merge,join,concat)
Jun 14 Python
Python 合并拼接字符串的方法
Jul 28 Python
关于pycharm 切换 python3.9 报错 ‘HTMLParser‘ object has no attribute ‘unescape‘ 的问题
Nov 24 Python
Python实现Keras搭建神经网络训练分类模型教程
Jun 12 #Python
简单了解Python变量作用域正确使用方法
Jun 12 #Python
keras 读取多标签图像数据方式
Jun 12 #Python
Python数据可视化图实现过程详解
Jun 12 #Python
浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别
Jun 12 #Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 #Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 #Python
You might like
php与php MySQL 之间的关系
2009/07/17 PHP
CodeIgniter与PHP5.6的兼容问题
2015/07/16 PHP
jQuery 打造动态渐变按钮 详细图文教程
2010/04/25 Javascript
MC Dialog js弹出层 完美兼容多浏览器(5.6更新)
2010/05/06 Javascript
Javascript事件热键兼容ie|firefox
2010/12/30 Javascript
js实现图片放大缩小功能后进行复杂排序的方法
2012/11/08 Javascript
jquery改变tr背景色的示例代码
2013/12/28 Javascript
jQuery实现仿百度首页滑动伸缩展开的添加服务效果代码
2015/09/09 Javascript
js实时获取窗口大小变化的实例代码
2016/11/18 Javascript
无阻塞加载js,防止因js加载不了影响页面显示的问题
2016/12/18 Javascript
angular.js+node.js实现下载图片处理详解
2017/03/31 Javascript
详解nodejs express下使用redis管理session
2017/04/24 NodeJs
用vue和node写的简易购物车实现
2017/04/25 Javascript
详解解决小程序中webview页面多层history返回问题
2019/08/20 Javascript
jquery+css3实现的经典弹出层效果示例
2020/05/16 jQuery
如何在微信小程序中使用骨架屏的步骤
2020/06/12 Javascript
Vue+Bootstrap收藏(点赞)功能逻辑与具体实现
2020/10/22 Javascript
[40:57]TI4 循环赛第二日 iG vs EG
2014/07/11 DOTA
python获取图片颜色信息的方法
2015/03/18 Python
Python中利用函数装饰器实现备忘功能
2015/03/30 Python
Python脚本获取操作系统版本信息
2016/12/17 Python
python中pika模块问题的深入探究
2018/10/13 Python
Pandas透视表(pivot_table)详解
2019/07/22 Python
python使用sklearn实现决策树的方法示例
2019/09/12 Python
python连接mysql有哪些方法
2020/06/24 Python
Django 实现图片上传和下载功能
2020/12/31 Python
CSS3中的元素过渡属性transition示例详解
2016/11/30 HTML / CSS
产品质量承诺书
2014/03/27 职场文书
初中毕业典礼演讲稿
2014/09/09 职场文书
2014年小学生教师节演讲稿范文
2014/09/10 职场文书
党员个人剖析材料2014
2014/10/08 职场文书
2014村党支部书记党建工作汇报材料
2014/11/02 职场文书
英文感谢信格式
2015/01/21 职场文书
2015年化验室工作总结
2015/04/23 职场文书
2015秋季小学开学寄语
2015/05/27 职场文书
IDEA2021.2配置docker如何将springboot项目打成镜像一键发布部署
2021/09/25 Java/Android