keras 自定义loss model.add_loss的使用详解


Posted in Python onJune 22, 2020

一点见解,不断学习,欢迎指正

1、自定义loss层作为网络一层加进model,同时该loss的输出作为网络优化的目标函数

from keras.models import Model
import keras.layers as KL
import keras.backend as K
import numpy as np
from keras.utils.vis_utils import plot_model
 
x_train=np.random.normal(1,1,(100,784))
 
x_in = KL.Input(shape=(784,))
x = x_in
x = KL.Dense(100, activation='relu')(x)
x = KL.Dense(784, activation='sigmoid')(x)
def custom_loss1(y_true,y_pred):
 return K.mean(K.abs(y_true-y_pred))
loss1=KL.Lambda(lambda x:custom_loss1(*x),name='loss1')([x,x_in])
 
model = Model(x_in, [loss1])
model.get_layer('loss1').output#取出loss
model.add_loss(loss1)#作为网络优化的目标函数
model.compile(optimizer='adam')
plot_model(model,to_file='model.png',show_shapes=True)
#
model.fit(x_train, None, epochs=5)

2、自定义loss,作为网络优化的目标函数

x_in = KL.Input(shape=(784,))
x = x_in
x = KL.Dense(100, activation='relu')(x)
x = KL.Dense(784, activation='sigmoid')(x)
 
model = Model(x_in, x)
loss = K.mean((x - x_in)**2)
model.add_loss(loss)#只是作为loss优化目标函数
model.compile(optimizer='adam')
plot_model(model,to_file='model.png',show_shapes=True)
model.fit(x_train, None, epochs=5)

补充知识:keras load_weights fine-tune

分享一个小技巧,就是在构建网络模型的时候,不要怕麻烦,给每一层都定义一个名字,这样在复用之前的参数权重的时候,除了官网给的先加载权重,再冻结权重之外,你可以通过简单的修改层的名字来达到加载之前训练的权重的目的,假设权重文件保存为model_pretrain.h5 ,重新使用的时候,我把想要复用的层的名字设置成一样的,然后

model.load_weights('model_pretrain.h5', by_name=True)

以上这篇keras 自定义loss model.add_loss的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
如何用Python实现简单的Markdown转换器
Jul 16 Python
Python JSON格式数据的提取和保存的实现
Mar 22 Python
Python生成rsa密钥对操作示例
Apr 26 Python
Django中信号signals的简单使用方法
Jul 04 Python
PIL图像处理模块paste方法简单使用详解
Jul 17 Python
python并发编程 Process对象的其他属性方法join方法详解
Aug 20 Python
python文字转语音实现过程解析
Nov 12 Python
python DES加密与解密及hex输出和bs64格式输出的实现代码
Apr 13 Python
Python实现弹球小游戏
Aug 01 Python
Python制作简单的剪刀石头布游戏
Dec 10 Python
Python爬虫实现selenium处理iframe作用域问题
Jan 27 Python
利用Python读取微信朋友圈的多种方法总结
Aug 23 Python
Python项目跨域问题解决方案
Jun 22 #Python
python os模块在系统管理中的应用
Jun 22 #Python
解决tensorflow读取本地MNITS_data失败的原因
Jun 22 #Python
python实现猜数游戏(保存游戏记录)
Jun 22 #Python
基于Tensorflow读取MNIST数据集时网络超时的解决方式
Jun 22 #Python
在Mac中配置Python虚拟环境过程解析
Jun 22 #Python
tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this T
Jun 22 #Python
You might like
php将session放入memcached的设置方法
2014/02/14 PHP
php在apache环境下实现gzip配置方法
2015/04/02 PHP
PHP中phar包的使用教程
2017/06/14 PHP
JavaScript 设计模式 富有表现力的Javascript(一)
2010/05/26 Javascript
JS中处理与当前时间间隔的函数代码
2012/05/23 Javascript
Array.prototype.concat不是通用方法反驳[译]
2012/09/20 Javascript
jQuery+CSS 半开折叠效果原理及代码(自写)
2013/03/04 Javascript
jQuery把表单元素变为json对象
2013/11/06 Javascript
javascript实现的弹出层背景置灰-模拟(easyui dialog)
2013/12/27 Javascript
jquery解析xml字符串简单示例
2014/04/11 Javascript
2014 HTML5/CSS3热门动画特效TOP10
2014/12/07 Javascript
jQuery中outerHeight()方法用法实例
2015/01/19 Javascript
JavaScript 事件入门知识
2015/04/13 Javascript
使用AngularJS编写较为优美的JavaScript代码指南
2015/06/19 Javascript
AngularJS整合Springmvc、Spring、Mybatis搭建开发环境
2016/02/25 Javascript
jQuery.Form上传文件操作
2017/02/05 Javascript
深入理解js中的加载事件
2017/02/08 Javascript
js实现PC端和移动端刮卡效果
2020/03/27 Javascript
用jQuery实现圆点图片轮播效果
2017/03/19 Javascript
解决vue router使用 history 模式刷新后404问题
2017/07/19 Javascript
jQuery实现简单弹幕制作
2020/12/10 jQuery
python错误:AttributeError: 'module' object has no attribute 'setdefaultencoding'问题的解决方法
2014/08/22 Python
利用python操作SQLite数据库及文件操作详解
2017/09/22 Python
Django models filter筛选条件详解
2020/03/16 Python
详解Python中pyautogui库的最全使用方法
2020/04/01 Python
Django在Model保存前记录日志实例
2020/05/14 Python
HTML5实现经典坦克大战坦克乱走还能发出一个子弹
2013/09/02 HTML / CSS
下列程序在32位linux或unix中的结果是什么
2015/01/26 面试题
你所在的项目是如何确定版本号的
2015/12/28 面试题
畜牧兽医本科生个人的自我评价
2013/10/11 职场文书
励志演讲稿500字
2014/08/21 职场文书
入股合作协议书
2014/10/12 职场文书
2014年技术员工作总结
2014/11/18 职场文书
中学生社区服务活动报告
2015/02/05 职场文书
2015年环境整治工作总结
2015/05/22 职场文书
追悼会家属答谢词
2015/09/29 职场文书