softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Linux下将Python的Django项目部署到Apache服务器
Dec 24 Python
Python解析json文件相关知识学习
Mar 01 Python
python daemon守护进程实现
Aug 27 Python
python中执行shell的两种方法总结
Jan 10 Python
python3个性签名设计实现代码
Jun 19 Python
使用pandas模块读取csv文件和excel表格,并用matplotlib画图的方法
Jun 22 Python
Django组件之cookie与session的使用方法
Jan 10 Python
Python中整数的缓存机制讲解
Feb 16 Python
Python3爬虫之自动查询天气并实现语音播报
Feb 21 Python
Django上线部署之IIS的配置方法
Aug 22 Python
python opencv 实现读取、显示、写入图像的方法
Jun 08 Python
Python 找出英文单词列表(list)中最长单词链
Dec 14 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
根德Grundig S400/S500/S700电路分析
2021/03/02 无线电
磨咖啡豆的密诀
2021/03/03 冲泡冲煮
PHP中对缓冲区的控制实现代码
2013/09/29 PHP
Windows和Linux中php代码调试工具Xdebug的安装与配置详解
2014/05/08 PHP
PHP中strtr字符串替换用法详解
2014/11/26 PHP
PHP数组去重比较快的实现方式
2016/01/19 PHP
thinkphp jquery实现图片上传和预览效果
2020/07/22 PHP
jQuery遍历之next()、nextAll()方法使用实例
2014/11/08 Javascript
JavaScript面向对象的实现方法小结
2015/04/14 Javascript
浅析BootStrap Treeview的简单使用
2016/10/12 Javascript
AngularJS递归指令实现Tree View效果示例
2016/11/07 Javascript
微信小程序 同步请求授权的详解
2017/08/04 Javascript
使用vue-cli webpack 快速搭建项目的代码
2018/11/21 Javascript
VuePress 中如何增加用户登录功能
2019/11/29 Javascript
Node Mongoose用法详解【Mongoose使用、Schema、对象、model文档等】
2020/05/13 Javascript
详解webpack的clean-webpack-plugin插件报错
2020/10/16 Javascript
python实现的守护进程(Daemon)用法实例
2015/06/02 Python
轻松实现python搭建微信公众平台
2016/02/16 Python
Python+matplotlib实现华丽的文本框演示代码
2018/01/22 Python
浅谈python的深浅拷贝以及fromkeys的用法
2019/03/08 Python
使用python-opencv读取视频,计算视频总帧数及FPS的实现
2019/12/10 Python
Python实现桌面翻译工具【新手必学】
2020/02/12 Python
python中get和post有什么区别
2020/06/19 Python
无需压缩软件,用python帮你操作压缩包
2020/08/17 Python
python 基于opencv 实现一个鼠标绘图小程序
2020/12/11 Python
阿玛尼美妆英国官网:Giorgio Armani Beauty英国
2019/03/28 全球购物
会计学财务管理专业个人的自我评价
2013/10/19 职场文书
自动化职业生涯规划书范文
2014/01/03 职场文书
带薪年假请假条
2014/02/04 职场文书
同学会主持词
2014/03/18 职场文书
暑期辅导班宣传单
2015/07/14 职场文书
2016年基层党组织公开承诺书
2016/03/25 职场文书
辞职信怎么写?
2019/05/21 职场文书
pytorch通过训练结果的复现设置随机种子
2021/06/01 Python
浅谈resultMap的用法及关联结果集映射
2021/06/30 Java/Android
修改Nginx配置返回指定content-type的方法
2022/09/23 Servers