浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python入门篇之数字
Oct 20 Python
Python实现简单的用户交互方法详解
Sep 25 Python
python多任务及返回值的处理方法
Jan 22 Python
python中时间模块的基本使用教程
May 14 Python
Python实现字符串中某个字母的替代功能
Oct 21 Python
基于python调用psutil模块过程解析
Dec 20 Python
pytorch torch.nn.AdaptiveAvgPool2d()自适应平均池化函数详解
Jan 03 Python
Python的PIL库中getpixel方法的使用
Apr 09 Python
关于jupyter打开之后不能直接跳转到浏览器的解决方式
Apr 13 Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 Python
Python如何给函数库增加日志功能
Aug 04 Python
pytorch 实现多个Dataloader同时训练
May 29 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
PHP 5昨天隆重推出--PHP 5/Zend Engine 2.0新特性
2006/10/09 PHP
微信支付开发动态链接Native支付
2016/07/12 PHP
php实现简单加入购物车功能
2017/03/07 PHP
解决 firefox 不支持 document.all的方法
2007/03/12 Javascript
javascript 强制刷新页面的实现代码
2009/12/13 Javascript
关于jQuery中.attr()和.prop()的问题探讨
2013/09/06 Javascript
Mac/Windows下如何安装Node.js
2013/11/22 Javascript
不要使用jQuery触发原生事件的方法
2014/03/03 Javascript
在Python中使用glob模块查找文件路径的方法
2015/06/17 Javascript
jquery.validate提示错误信息位置方法
2016/01/22 Javascript
JS 拦截全局ajax请求实例解析
2016/11/29 Javascript
Bootstrap基本样式学习笔记之标签(5)
2016/12/07 Javascript
如何解决jQuery EasyUI 已打开Tab重新加载问题
2016/12/19 Javascript
javascript中apply/call和bind的使用
2017/02/15 Javascript
jQuery EasyUI之验证框validatebox实例详解
2017/04/10 jQuery
微信小程序实战之自定义抽屉菜单(7)
2017/04/18 Javascript
Angular4开发解决跨域问题详解
2017/08/28 Javascript
Vue 2.0学习笔记之Vue中的computed属性
2017/10/16 Javascript
在JavaScript中如何访问暂未存在的嵌套对象
2019/06/18 Javascript
vue实现多个echarts根据屏幕大小变化而变化实例
2020/07/19 Javascript
[02:41]DOTA2英雄基础教程 冥魂大帝
2014/01/16 DOTA
Python面向对象程序设计之私有属性及私有方法示例
2019/04/08 Python
Python MySQLdb 执行sql语句时的参数传递方式
2020/03/04 Python
使用sklearn对多分类的每个类别进行指标评价操作
2020/06/11 Python
Nobody Denim官网:购买高级女士牛仔裤
2021/03/15 全球购物
EJB2和EJB3在架构上的不同点
2014/09/29 面试题
公司同意接收函
2014/01/13 职场文书
交通事故检查书范文
2014/01/30 职场文书
红色故事演讲稿
2014/05/22 职场文书
霸气队列口号
2014/06/18 职场文书
运动会加油稿100字
2014/09/19 职场文书
自我工作评价范文
2015/03/06 职场文书
勤俭节约倡议书范文
2015/04/29 职场文书
党员带头倡议书
2015/04/29 职场文书
考研经验交流会策划书
2015/11/02 职场文书
2016秋季幼儿园开学寄语
2015/12/03 职场文书