softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python开发中如何使用Hook技巧
Nov 01 Python
Python基于socket模块实现UDP通信功能示例
Apr 10 Python
python中使用iterrows()对dataframe进行遍历的实例
Jun 09 Python
Python判断字符串是否为字母或者数字(浮点数)的多种方法
Aug 03 Python
对python中Librosa的mfcc步骤详解
Jan 09 Python
python批量创建指定名称的文件夹
Mar 21 Python
numpy数组做图片拼接的实现(concatenate、vstack、hstack)
Nov 08 Python
django 框架实现的用户注册、登录、退出功能示例
Nov 28 Python
python 实现分组求和与分组累加求和代码
May 18 Python
pycharm 关闭search everywhere的解决操作
Jan 15 Python
如何用python插入独创性声明
Mar 31 Python
Python学习之异常中的finally使用详解
Mar 16 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
IIS6.0中配置php服务全过程解析
2013/08/07 PHP
ThinkPHP中ajax使用实例教程
2014/08/22 PHP
PHP记录页面停留时间的方法
2016/03/30 PHP
详解PHP使用Redis存储session时的一个Warning定位
2017/07/05 PHP
PHP+Ajax实现的博客文章添加类别功能示例
2018/03/29 PHP
php集成开发环境详解
2019/09/24 PHP
写js时遇到的一些小问题
2010/12/06 Javascript
可兼容IE的获取及设置cookie的jquery.cookie函数方法
2013/09/02 Javascript
如何解决Jquery库及其他库之间的$命名冲突
2013/09/15 Javascript
javascript中的括号()用法小结
2014/04/14 Javascript
函数式 JavaScript(一)简介
2014/07/07 Javascript
jqGrid表格应用之新增与删除数据附源码下载
2015/12/02 Javascript
使用bootstrapValidator插件进行动态添加表单元素并校验
2016/09/28 Javascript
jQuery插件FusionCharts绘制的2D条状图效果【附demo源码】
2017/05/13 jQuery
全面解析jQuery中的$(window)与$(document)的用法区别
2017/08/15 jQuery
解决vue 路由变化页面数据不刷新的问题
2018/03/13 Javascript
Vue 获取数组键名的方法
2018/06/21 Javascript
vue引入axios同源跨域问题
2018/09/27 Javascript
koa2使用ejs和nunjucks作为模板引擎的使用
2018/11/27 Javascript
微信小程序canvas分享海报功能
2019/10/31 Javascript
JavaScript数组常用的增删改查与其他属性详解
2020/10/13 Javascript
[04:31]2016国际邀请赛中国区预选赛妖精采访
2016/06/27 DOTA
[50:58]2018DOTA2亚洲邀请赛 4.1 小组赛 B组 Mineski vs EG
2018/04/03 DOTA
Python 实现网页自动截图的示例讲解
2018/05/17 Python
Python 获取numpy.array索引值的实例
2019/12/06 Python
pandas factorize实现将字符串特征转化为数字特征
2019/12/19 Python
python enumerate内置函数用法总结
2020/01/07 Python
python中upper是做什么用的
2020/07/20 Python
python 实现批量图片识别并翻译
2020/11/02 Python
面向对象编程是如何提高软件开发水平的
2014/05/06 面试题
浅谈react路由传参的几种方式
2021/03/23 Javascript
乡镇消防安全责任书
2014/07/23 职场文书
一份没有按时交货失信于客户的检讨书
2014/09/19 职场文书
个人四风问题对照检查材料
2014/09/26 职场文书
开票员岗位职责
2015/02/12 职场文书
优秀范文:《但愿人长久》教学反思3篇
2019/10/24 职场文书