tensorflow 自定义损失函数示例代码


Posted in Python onFebruary 05, 2020

这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)

我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。

如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。

显然,我宁愿预估多了,也不想预估少了。

所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。

(yhat沿用吴恩达课堂中的叫法)

import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 steps = 5000
 for i in range(steps):
  start = (i * batch_size) % dataset_size
  end = min(start + batch_size, dataset_size)
  sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
 print(sess.run(w1))

[[ 1.01934695]
[ 1.04280889]

最终结果如上面所示。

因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果:

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。

引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

总结

以上所述是小编给大家介绍的tensorflow 自定义损失函数示例,希望对大家有所帮助!

Python 相关文章推荐
python函数缺省值与引用学习笔记分享
Feb 10 Python
python抓取网页图片并放到指定文件夹
Apr 24 Python
在Python中实现贪婪排名算法的教程
Apr 17 Python
完美解决Python2操作中文名文件乱码的问题
Jan 04 Python
Python split() 函数拆分字符串将字符串转化为列的方法
Jul 16 Python
Django model select的多种用法详解
Jul 16 Python
Django中间件基础用法详解
Jul 18 Python
Django框架视图层URL映射与反向解析实例分析
Jul 29 Python
pytorch 彩色图像转灰度图像实例
Jan 13 Python
pytorch学习教程之自定义数据集
Nov 10 Python
python自动化测试之Selenium详解
Mar 13 Python
Pyhton爬虫知识之正则表达式详解
Apr 01 Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
TensorFlow实现指数衰减学习率的方法
Feb 05 #Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 #Python
解决Tensorflow sess.run导致的内存溢出问题
Feb 05 #Python
You might like
使用PHP编写发红包程序
2015/07/22 PHP
php下载文件超时时间的设置方法
2016/10/06 PHP
PHP设计模式之组合模式定义与应用示例
2020/02/01 PHP
php判断某个方法是否存在函数function_exists (),method_exists()与is_callable()区别与用法解析
2020/04/20 PHP
Tinymce+jQuery.Validation使用产生的BUG
2010/03/29 Javascript
用JQuery模仿淘宝的图片放大镜显示效果
2011/09/15 Javascript
js 动态为textbox添加下拉框数据源的方法
2014/04/24 Javascript
使用jquery+CSS实现控制打印样式
2014/12/31 Javascript
JS实现获取剪贴板内容的方法
2016/06/21 Javascript
JS 滚动事件window.onscroll与position:fixed写兼容IE6的回到顶部组件
2016/10/10 Javascript
js date 格式化
2017/02/15 Javascript
微信小程序 MD5的方法详解及实例代码
2017/03/10 Javascript
JS+HTML5 FileReader实现文件上传前本地预览功能
2020/03/27 Javascript
阿里大于短信验证码node koa2的实现代码(最新)
2017/09/07 Javascript
基于vue-simple-uploader封装文件分片上传、秒传及断点续传的全局上传插件功能
2021/02/23 Vue.js
Python发送form-data请求及拼接form-data内容的方法
2016/03/05 Python
Python基于回溯法子集树模板解决0-1背包问题实例
2017/09/02 Python
pandas 取出表中一列数据所有的值并转换为array类型的方法
2018/04/11 Python
Python利用pandas计算多个CSV文件数据值的实例
2018/04/19 Python
python删除文本中行数标签的方法
2018/05/31 Python
Python数据处理篇之Sympy系列(五)---解方程
2019/10/12 Python
Python tcp传输代码实例解析
2020/03/18 Python
Python下使用Trackbar实现绘图板
2020/10/27 Python
Python self用法详解
2020/11/28 Python
豪华床上用品 :Jennifer Adams
2019/09/15 全球购物
英国Iceland杂货店:网上食品购物
2020/12/16 全球购物
大专应届生个人的自我评价
2013/11/21 职场文书
党校培训自我鉴定范文
2014/04/10 职场文书
工地门卫岗位职责范本
2014/07/01 职场文书
2014党员民主评议个人思想剖析发言
2014/09/19 职场文书
小学教师求职信范文
2015/03/20 职场文书
读书笔记怎么写
2015/07/01 职场文书
2016最新离婚协议书范本及程序
2016/03/18 职场文书
优胜劣汰,强者为王——读《鲁滨逊漂流记》有感
2019/08/15 职场文书
自制短波长线天线频率预选器 - 成功消除B2K之流的镜像
2021/04/22 无线电
OpenCV-Python实现油画效果的实例
2021/06/08 Python