在keras中实现查看其训练loss值


Posted in Python onJune 16, 2020

想要查看每次训练模型后的 loss 值变化需要如下操作

loss_value= [ ]
self.history = model.fit(state,target_f,epochs=1, batch_size =32)
b = abs(float(self.history.history[‘loss'][0]))
loss_value.append(b)
print(loss_value)
loss_value = np.array( loss_value)
x = np.array(range(len( loss_value)))
plt.plot(x, loss_value, c = ‘g')
pt.svefit('c地址‘, dpi= 100)
plt.show()

scipy.sparse 稀疏矩阵 函数集合

pandas 用于在各种文件中提取,并处理分析数据; 有DataFrame数据结构,类似表格。

x=np.linspace(-10, 10, 100) 生成100个在-10到10之间的数组

补充知识:对keras训练过程中loss,val_loss,以及accuracy,val_accuracy的可视化

我就废话不多说了,大家还是直接看代码吧!

hist = model.fit_generator(generator=data_generator_reg(X=x_train, Y=[y_train_a,y_train_g], batch_size=batch_size),
         steps_per_epoch=train_num // batch_size,
         validation_data=(x_test, [y_test_a,y_test_g]),
         epochs=nb_epochs, verbose=1,
         workers=8, use_multiprocessing=True,
         callbacks=callbacks)

 logging.debug("Saving weights...")
 model.save_weights(os.path.join(db_name+"_models/"+save_name, save_name+'.h5'), overwrite=True)
 pd.DataFrame(hist.history).to_hdf(os.path.join(db_name+"_models/"+save_name, 'history_'+save_name+'.h5'), "history")

在训练时,会输出如下打印:

640/640 [==============================] - 35s 55ms/step - loss: 4.0216 - mean_absolute_error: 4.6525 - val_loss: 3.2888 - val_mean_absolute_error: 3.9109

有训练loss,训练预测准确度,以及测试loss,以及测试准确度,将文件保存后,使用下面的代码可以对训练以及评估进行可视化,下面有对应的参数名称:

loss,mean_absolute_error,val_loss,val_mean_absolute_error

import pandas as pd
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np

def get_args():
 parser = argparse.ArgumentParser(description="This script shows training graph from history file.")
 parser.add_argument("--input", "-i", type=str, required=True,
      help="path to input history h5 file")
 args = parser.parse_args()
 return args

def main():
 args = get_args()
 input_path = args.input

 df = pd.read_hdf(input_path, "history")
 print(np.min(df['val_mean_absolute_error']))
 input_dir = os.path.dirname(input_path)
 plt.plot(df["loss"], '-o', label="loss (age)", linewidth=2.0)
 plt.plot(df["val_loss"], '-o', label="val_loss (age)", linewidth=2.0)
 plt.xlabel("Number of epochs", fontsize=20)
 plt.ylabel("Loss", fontsize=20)
 plt.legend()
 plt.grid()
 plt.savefig(os.path.join(input_dir, "loss.pdf"), bbox_inches='tight', pad_inches=0)
 plt.cla()

 plt.plot(df["mean_absolute_error"], '-o', label="training", linewidth=2.0)
 plt.plot(df["val_mean_absolute_error"], '-o', label="validation", linewidth=2.0)
 ax = plt.gca()
 ax.set_ylim([2,13])
 ax.set_aspect(0.6/ax.get_data_ratio())
 plt.xticks(fontsize=20)
 plt.yticks(fontsize=20)
 plt.xlabel("Number of epochs", fontsize=20)
 plt.ylabel("Mean absolute error", fontsize=20)
 plt.legend(fontsize=20)
 plt.grid()
 plt.savefig(os.path.join(input_dir, "performance.pdf"), bbox_inches='tight', pad_inches=0)

if __name__ == '__main__':
 main()

以上这篇在keras中实现查看其训练loss值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python正则简单实例分析
Mar 21 Python
Python Socket实现简单TCP Server/client功能示例
Aug 05 Python
python中正则表达式的使用方法
Feb 25 Python
Django代码性能优化与Pycharm Profile使用详解
Aug 26 Python
python程序变成软件的实操方法
Jun 24 Python
利用Python的turtle库绘制玫瑰教程
Nov 23 Python
python实现opencv+scoket网络实时图传
Mar 20 Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
Jun 12 Python
Python学习之路安装pycharm的教程详解
Jun 17 Python
Django Model层F,Q对象和聚合函数原理解析
Nov 12 Python
Python学习开发之图形用户界面详解
Aug 23 Python
Python数据结构之队列详解
Mar 21 Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
如何在Windows中安装多个python解释器
Jun 16 #Python
使用pyplot.matshow()函数添加绘图标题
Jun 16 #Python
浅谈matplotlib中FigureCanvasXAgg的用法
Jun 16 #Python
利用Python实现Excel的文件间的数据匹配功能
Jun 16 #Python
Pytorch 使用CNN图像分类的实现
Jun 16 #Python
You might like
php smarty截取中文字符乱码问题?gb2312/utf-8
2011/11/07 PHP
php自动加载机制的深入分析
2013/06/08 PHP
php getcwd与dirname(__FILE__)区别详解
2016/09/24 PHP
实例:尽可能写友好的Javascript代码
2006/10/09 Javascript
firefox中JS读取XML文件
2006/12/21 Javascript
js停止输出代码
2008/07/20 Javascript
5款Javascript颜色选择器
2009/10/25 Javascript
基于mootools插件实现遮罩层新手引导
2012/05/24 Javascript
jquery方法+js一般方法+js面向对象方法实现拖拽效果
2012/08/30 Javascript
extjs 3.31 TreeGrid实现静态页面加载json到TreeGrid里面
2013/04/02 Javascript
jquery等待效果示例
2014/05/01 Javascript
jQuery插件实现文字无缝向上滚动效果代码
2016/02/25 Javascript
详解JavaScript中基于原型prototype的继承特性
2016/05/05 Javascript
jQuery中设置form表单中action值的实现方法
2016/05/25 Javascript
JS实现鼠标移上去显示图片或微信二维码
2016/12/14 Javascript
jQuery实现table表格信息的展开和缩小功能示例
2018/07/21 jQuery
layui的table单击行勾选checkbox功能方法
2018/08/14 Javascript
vue: WebStorm设置快速编译运行的方法
2018/10/18 Javascript
vue-router 起步步骤详解
2019/03/26 Javascript
vue.js中ref和$refs的使用及示例讲解
2019/08/14 Javascript
tracking.js实现前端人脸识别功能
2020/04/16 Javascript
[04:02]2014DOTA2国际邀请赛 BBC每日综述中国战队将再度登顶
2014/07/21 DOTA
[03:26]《DAC最前线》之EG经理自述DOTA2经历
2015/02/02 DOTA
[01:29:17]RNG vs Liquid 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.23
2019/09/05 DOTA
python从入门到精通(DAY 1)
2015/12/20 Python
使用Python内置的模块与函数进行不同进制的数的转换
2016/03/12 Python
tensorflow 获取变量&打印权值的实例讲解
2018/06/14 Python
python中update的基本使用方法详解
2019/07/17 Python
Python占用的内存优化教程
2019/07/28 Python
印度领先的在线时尚商店:Koovs
2016/08/28 全球购物
Max&Co官网:意大利年轻女性时尚品牌
2017/05/16 全球购物
Kate Spade美国官网:纽约新兴时尚品牌,以包包闻名于世
2017/11/09 全球购物
校园门卫岗位职责
2013/12/09 职场文书
机械工程学院大学生求职信
2014/05/25 职场文书
体育课外活动总结
2014/07/08 职场文书
MySQL Innodb关键特性之插入缓冲(insert buffer)
2021/04/08 MySQL