浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
tornado框架blog模块分析与使用
Nov 21 Python
python模块之StringIO使用示例
Apr 08 Python
在Python中用has_key()方法查找键是否存在的教程
May 21 Python
Python回调函数用法实例详解
Jul 02 Python
Python按行读取文件的简单实现方法
Jun 22 Python
Python 实现 贪吃蛇大作战 代码分享
Sep 07 Python
Python处理PDF及生成多层PDF实例代码
Apr 24 Python
django框架实现一次性上传多个文件功能示例【批量上传】
Jun 19 Python
Python如何筛选序列中的元素的方法实现
Jul 15 Python
python安装本地whl的实例步骤
Oct 12 Python
一文读懂python Scrapy爬虫框架
Feb 24 Python
Django基础CBV装饰器和中间件
Mar 22 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
php延迟静态绑定实例分析
2015/02/08 PHP
百度工程师讲PHP函数的实现原理及性能分析(二)
2015/05/13 PHP
什么是PHP文件?如何打开PHP文件?
2017/06/27 PHP
javascript div 遮罩层封锁整个页面
2009/07/10 Javascript
Javascript(AJAX)解析XML的代码(兼容FIREFOX/IE)
2010/07/11 Javascript
基于JQuery的数字改变的动画效果--可用来做计数器
2010/08/11 Javascript
jQuery学习笔记之jQuery的动画
2010/12/22 Javascript
getElementByIdx_x js自定义getElementById函数
2012/01/24 Javascript
node.js中的http.response.setHeader方法使用说明
2014/12/14 Javascript
基于jQuery实现文本框只能输入数字(小数、整数)
2016/01/14 Javascript
JavaScript入门教程之引用类型
2016/05/04 Javascript
node.js调用Chrome浏览器打开链接地址的方法
2017/05/17 Javascript
bootstrap daterangepicker双日历时间段选择控件详解
2017/06/15 Javascript
详解angular分页插件tm.pagination二次触发问题解决方案
2018/07/20 Javascript
移动端(微信等使用vConsole调试console的方法
2019/03/05 Javascript
Vue + Elementui实现多标签页共存的方法
2019/06/12 Javascript
vue集成chart.js的实现方法
2019/08/20 Javascript
jquery实现点击弹出对话框
2020/02/08 jQuery
Python中用Descriptor实现类级属性(Property)详解
2014/09/18 Python
Python 将pdf转成图片的方法
2018/04/23 Python
django中ORM模型常用的字段的使用方法
2019/03/05 Python
python实现银联支付和支付宝支付接入
2019/05/07 Python
HTML5录音实践总结(Preact)
2020/05/07 HTML / CSS
理肤泉英国官网:La Roche-Posay英国
2019/01/14 全球购物
德国在线香料制造商:Gewürzland
2020/03/10 全球购物
Tessabit美国:集世界奢侈品和设计师品牌的意大利精品买手店
2020/06/29 全球购物
阅兵口号
2014/06/19 职场文书
党的生日演讲稿
2014/09/10 职场文书
2014年中学生检讨书大全
2014/10/09 职场文书
2015年学校教育教学工作总结
2015/04/22 职场文书
法律服务所工作总结
2015/08/10 职场文书
关于办理居住证的介绍信模板
2019/11/27 职场文书
redis限流的实际应用
2021/04/24 Redis
【海涛教你打DOTA】剑圣第一人称视角解说
2022/04/01 DOTA
MongoDB误操作后使用oplog恢复数据
2022/04/11 MongoDB
MySQL视图概念以及相关应用
2022/04/19 MySQL