keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现学校管理系统
Jan 11 Python
Django csrf 验证问题的实现
Oct 09 Python
python内存管理机制原理详解
Aug 12 Python
python数据处理之如何选取csv文件中某几行的数据
Sep 02 Python
python  logging日志打印过程解析
Oct 22 Python
Python基本类型的连接组合和互相转换方式(13种)
Dec 16 Python
基于Keras中Conv1D和Conv2D的区别说明
Jun 19 Python
宝塔面板成功部署Django项目流程(图文)
Jun 22 Python
python将字典内容写入json文件的实例代码
Aug 12 Python
python3.5的包存放的具体路径
Aug 16 Python
浅谈Python __init__.py的作用
Oct 28 Python
python基于tkinter制作下班倒计时工具
Apr 28 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
zf框架的数据库追踪器使用示例
2014/03/13 PHP
用PHP代码在网页上生成图片
2015/07/01 PHP
JavaScript confirm选择判断
2008/10/18 Javascript
JavaScript 提升运行速度之循环篇 译文
2009/08/15 Javascript
让ie运行js时提示允许阻止内容运行的解决方法
2010/10/24 Javascript
jQuery EasyUI API 中文文档 - Tree树使用介绍
2011/11/19 Javascript
基于jquery创建的一个图片、视频缓冲的效果样式插件
2012/08/28 Javascript
如何设置iframe高度自适应在跨域情况下的可用方法
2013/09/06 Javascript
javascript中match函数的用法小结
2014/02/08 Javascript
PHP结合jQuery实现的评论顶、踩功能
2015/07/22 Javascript
使用CoffeeScrip优美方式编写javascript代码
2015/10/28 Javascript
javascript中checkbox使用方法实例演示
2015/11/19 Javascript
jQuery使用模式窗口实现在主页面和子页面中互相传值的方法
2016/03/01 Javascript
AngularJS包括详解及示例代码
2016/08/17 Javascript
详解webpack loader和plugin编写
2018/10/12 Javascript
详解使用WebPack搭建React开发环境
2019/08/06 Javascript
微信小程序分包加载代码实现方法详解
2019/09/23 Javascript
JavaScript实现猜数字游戏
2020/05/20 Javascript
Python 常用的安装Module方式汇总
2017/05/06 Python
对python中的pop函数和append函数详解
2018/05/04 Python
python 编写简单网页服务器的实例
2018/06/01 Python
便捷提取python导入包的属性方法
2018/10/15 Python
python 读取修改pcap包的例子
2019/07/23 Python
Python字典常见操作实例小结【定义、添加、删除、遍历】
2019/10/25 Python
selenium与xpath之获取指定位置的元素的实现
2021/01/26 Python
HTML5利用约束验证API来检查表单的输入数据的代码实例
2016/12/20 HTML / CSS
巴西服装和鞋子购物网站:Marisa
2018/10/25 全球购物
What's the difference between deep copy and shallow copy? (深拷贝与浅拷贝有什么区别)
2015/11/10 面试题
护士实习鉴定范文
2013/12/22 职场文书
影子教师研修方案
2014/06/14 职场文书
大学生党员自我剖析材料
2014/10/06 职场文书
酒店辞职书怎么写
2015/02/26 职场文书
如何在Python中创建二叉树
2021/03/30 Python
JavaScript实现酷炫的鼠标拖尾特效
2022/02/18 Javascript
剖析后OpLog订阅MongoDB的数据变更就没那么难了
2022/02/24 MongoDB
Win11安装升级时提示“该电脑必须支持安全启动”
2022/04/19 数码科技