keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python面向对象特殊成员
Apr 24 Python
Python开发的十个小贴士和技巧及长常犯错误
Sep 27 Python
python tkinter canvas 显示图片的示例
Jun 13 Python
Django forms表单 select下拉框的传值实例
Jul 19 Python
Python 转换RGB颜色值的示例代码
Oct 13 Python
python读取ini配置文件过程示范
Dec 23 Python
Python实现仿射密码的思路详解
Apr 23 Python
使用Keras预训练好的模型进行目标类别预测详解
Jun 27 Python
python利用蒙版抠图(使用PIL.Image和cv2)输出透明背景图
Aug 04 Python
Django配置跨域并开发测试接口
Nov 04 Python
用python自动生成日历
Apr 24 Python
在Python中如何使用yield
Jun 07 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
在同一窗体中使用PHP来处理多个提交任务
2008/05/08 PHP
php遍历文件夹所有文件子文件夹函数代码
2013/11/27 PHP
php文件扩展名判断及获取文件扩展名的N种方法
2015/09/12 PHP
PHP读取XML格式文件的方法总结
2017/02/27 PHP
PHP 断点续传实例详解
2017/11/11 PHP
Laravel框架之解决前端显示图片问题
2019/10/24 PHP
PHP实现抽奖功能实例代码
2020/06/30 PHP
javascript OFFICE控件测试代码
2009/12/08 Javascript
javascript object array方法使用详解
2012/12/03 Javascript
jQuery写fadeTo示例代码
2014/02/21 Javascript
七个很有意思的PHP函数
2014/05/12 Javascript
Javascript window对象详解
2014/11/12 Javascript
jQuery中filter()方法用法实例
2015/01/06 Javascript
使用JavaScript刷新网页的方法
2015/06/04 Javascript
深入理解JavaScript中的for循环
2017/02/07 Javascript
Vue2 配置 Axios api 接口调用文件的方法
2017/11/13 Javascript
jquery的 filter()方法使用教程
2018/03/22 jQuery
vue单页面实现当前页面刷新或跳转时提示保存
2018/11/02 Javascript
JavaScript实现公告栏上下滚动效果
2020/03/13 Javascript
Openlayers实现地图的基本操作
2020/09/28 Javascript
python的继承知识点总结
2018/12/10 Python
Python 3.3实现计算两个日期间隔秒数/天数的方法示例
2019/01/07 Python
python 读写excel文件操作示例【附源码下载】
2019/06/19 Python
python将字符串转换成json的方法小结
2019/07/09 Python
关于tensorflow的几种参数初始化方法小结
2020/01/04 Python
Python读取表格类型文件代码实例
2020/02/17 Python
解决jupyter notebook打不开无反应 浏览器未启动的问题
2020/04/10 Python
python 元组的使用方法
2020/06/09 Python
javascript实现用户必须勾选协议实例讲解
2021/03/24 Javascript
学生打架检讨书大全
2014/01/23 职场文书
银行贷款承诺书
2014/03/29 职场文书
合同协议书格式
2014/04/18 职场文书
2014年党课学习材料
2014/05/11 职场文书
《水上飞机》教学反思
2016/02/20 职场文书
品牌形象定位,全面分析
2019/07/23 职场文书
解决Django transaction进行事务管理踩过的坑
2021/04/24 Python