keras处理欠拟合和过拟合的实例讲解


Posted in Python onMay 25, 2020

baseline

import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()

baseline_history = baseline_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

小模型

small_model = keras.Sequential(
[
 layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(4, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

大模型

big_model = keras.Sequential(
[
 layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(512, activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)

绘图比较上述三个模型

def plot_history(histories, key='binary_crossentropy'):
 plt.figure(figsize=(16,10))
 
 for name, history in histories:
 val = plt.plot(history.epoch, history.history['val_'+key],
     '--', label=name.title()+' Val')
 plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
    label=name.title()+' Train')

 plt.xlabel('Epochs')
 plt.ylabel(key.replace('_',' ').title())
 plt.legend()

 plt.xlim([0,max(history.epoch)])


plot_history([('baseline', baseline_history),
    ('small', small_history),
    ('big', big_history)])

keras处理欠拟合和过拟合的实例讲解

三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象

大模型在训练集上表现更好,过拟合的速度更快

l2正则减少过拟合

l2_model = keras.Sequential(
[
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), 
     activation='relu'),
 layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('l2', l2_history)])

keras处理欠拟合和过拟合的实例讲解

可以发现正则化之后的模型在验证集上的过拟合程度减少

添加dropout减少过拟合

dpt_model = keras.Sequential(
[
 layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
 layers.Dropout(0.5),
 layers.Dense(16, activation='relu'),
 layers.Dropout(0.5),
 layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
      loss='binary_crossentropy',
      metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
          epochs=20, batch_size=512,
          validation_data=(test_data, test_labels),
          verbose=2)
plot_history([('baseline', baseline_history),
    ('dropout', dpt_history)])

keras处理欠拟合和过拟合的实例讲解

批正则化

model = keras.Sequential([
 layers.Dense(64, activation='relu', input_shape=(784,)),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(64, activation='relu'),
 layers.BatchNormalization(),
 layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()

总结

防止神经网络中过度拟合的最常用方法:

获取更多训练数据。

减少网络容量。

添加权重正规化。

添加dropout。

以上这篇keras处理欠拟合和过拟合的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现Mysql数据库连接池实例详解
Apr 11 Python
pytorch cnn 识别手写的字实现自建图片数据
May 20 Python
Python numpy实现二维数组和一维数组拼接的方法
Jun 05 Python
pandas读取csv文件,分隔符参数sep的实例
Dec 12 Python
Python进度条的制作代码实例
Aug 31 Python
python实现统计代码行数的小工具
Sep 19 Python
python多线程并发及测试框架案例
Oct 15 Python
Windows 下更改 jupyterlab 默认启动位置的教程详解
May 18 Python
Python 如何展开嵌套的序列
Aug 01 Python
python opencv实现简易画图板
Aug 27 Python
mac安装python3后使用pip和pip3的区别说明
Sep 01 Python
浅析python字符串前加r、f、u、l 的区别
Jan 24 Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
关于keras中keras.layers.merge的用法说明
May 23 #Python
使用keras2.0 将Merge层改为函数式
May 23 #Python
You might like
使用php实现从身份证中提取生日
2016/05/09 PHP
onsubmit阻止form表单提交与onclick的相关操作
2010/09/03 Javascript
ExtJs的Date格式字符代码
2010/12/30 Javascript
3种Jquery限制文本框只能输入数字字母的方法
2014/12/03 Javascript
javascript每日必学之运算符
2016/02/16 Javascript
Javascript highcharts 饼图显示数量和百分比实例代码
2016/12/06 Javascript
js实现导航栏中英文切换效果
2017/01/16 Javascript
浅谈js函数三种定义方式 & 四种调用方式 & 调用顺序
2017/02/19 Javascript
解决IE11 vue +webpack 项目中数据更新后页面没有刷新的问题
2018/09/25 Javascript
详解项目升级到vue-cli3的正确姿势
2019/01/28 Javascript
vue基础之模板和过滤器用法实例分析
2019/03/12 Javascript
javascript sort()对数组中的元素进行排序详解
2019/10/13 Javascript
python实现将英文单词表示的数字转换成阿拉伯数字的方法
2015/07/02 Python
将Python的Django框架与认证系统整合的方法
2015/07/24 Python
Windows系统下使用flup搭建Nginx和Python环境的方法
2015/12/25 Python
利用Python破解斗地主残局详解
2017/06/30 Python
CentOS 7下安装Python3.6 及遇到的问题小结
2018/11/08 Python
在python 不同时区之间的差值与转换方法
2019/01/14 Python
Django stark组件使用及原理详解
2019/08/22 Python
python对Excel按条件进行内容补充(推荐)
2019/11/24 Python
Python scrapy增量爬取实例及实现过程解析
2019/12/24 Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
2020/09/17 Python
详解Django配置JWT认证方式
2020/05/09 Python
JavaScript+Canvas实现自定义画板的示例代码
2019/05/13 HTML / CSS
日语系毕业生推荐信
2013/11/11 职场文书
商场促销活动方案
2014/02/08 职场文书
医师定期考核实施方案
2014/05/07 职场文书
学校安全生产承诺书
2014/05/23 职场文书
父亲节活动策划方案
2014/08/24 职场文书
小学生优秀评语
2014/12/29 职场文书
幼儿园中秋节活动总结
2015/03/23 职场文书
开学第一周总结
2015/07/16 职场文书
教务处教学工作总结
2015/08/10 职场文书
预备党员表决心的话
2015/09/22 职场文书
省级三好学生主要事迹材料
2015/11/03 职场文书
带你了解Java中的ForkJoin
2022/04/28 Java/Android