softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python os模块中的isfile()和isdir()函数均返回false问题解决方法
Feb 04 Python
Django框架中数据的连锁查询和限制返回数据的方法
Jul 17 Python
python生成式的send()方法(详解)
May 08 Python
Python用于学习重要算法的模块pygorithm实例浅析
Aug 16 Python
pygame游戏之旅 添加icon和bgm音效的方法
Nov 21 Python
python调用接口的4种方式代码实例
Nov 19 Python
python实现信号时域统计特征提取代码
Feb 26 Python
配置python的编程环境之Anaconda + VSCode的教程
Mar 29 Python
jupyter notebook实现显示行号
Apr 13 Python
浅谈pymysql查询语句中带有in时传递参数的问题
Jun 05 Python
python 图像增强算法实现详解
Jan 24 Python
python 远程执行命令的详细代码
Feb 15 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
PHP中文件上传的一个问题
2010/09/04 PHP
mongo Table类文件 获取MongoCursor(游标)的实现方法分析
2013/07/01 PHP
PHP生成迅雷、快车、旋风等软件的下载链接代码实例
2014/05/12 PHP
ThinkPHP权限认证Auth实例详解
2014/07/22 PHP
PHP使用Nginx实现反向代理
2017/09/20 PHP
基于JQuery模仿苹果桌面的Dock效果(初级版)
2012/10/15 Javascript
Ext4.2的Ext.grid.plugin.RowExpander无法触发事件解决办法
2014/08/15 Javascript
jQuery判断元素上是否绑定了指定事件的方法
2015/03/17 Javascript
详解JavaScript for循环中发送AJAX请求问题
2020/06/23 Javascript
JS不用正则验证输入的字符串是否为空(包含空格)的实现代码
2016/06/14 Javascript
基于js粘贴事件paste简单解析以及遇到的坑
2017/09/07 Javascript
ES6之模版字符串的具体使用
2018/05/17 Javascript
vue 自动化路由实现代码
2019/09/03 Javascript
vue-next/runtime-core 源码阅读指南详解
2019/10/25 Javascript
原生js实现点击轮播切换图片
2020/02/11 Javascript
[01:09]模型精美,特效酷炫!TI9不朽宝藏Ⅰ鉴赏
2019/05/10 DOTA
[34:41]夜魇凡尔赛茶话会 第二期02:你画我猜
2021/03/11 DOTA
python 网络编程详解及简单实例
2017/04/25 Python
python flask实现分页效果
2017/06/27 Python
python 读写文件,按行修改文件的方法
2018/07/12 Python
python 调用pyautogui 实时获取鼠标的位置、移动鼠标的方法
2019/08/27 Python
Python模块future用法原理详解
2020/01/20 Python
微软开源最强Python自动化神器Playwright(不用写一行代码)
2021/01/05 Python
Python 将代码转换为可执行文件脱离python环境运行(步骤详解)
2021/01/25 Python
Python tkinter之Bind(绑定事件)的使用示例
2021/02/05 Python
美国受信赖的教育产品供应商:Nest Learning
2018/06/14 全球购物
买卖正宗运动鞋:GOAT
2019/12/06 全球购物
学院书画协会部门职责
2013/11/28 职场文书
中式面点餐厅创业计划书
2014/01/29 职场文书
致1500米运动员广播稿
2014/02/07 职场文书
优秀护士先进事迹
2014/05/08 职场文书
大学优秀班集体申报材料
2014/05/23 职场文书
小学生安全教育广播稿
2014/10/20 职场文书
读《儒林外史》有感:少一些功利,多一些真诚
2020/01/19 职场文书
详解Vue slot插槽
2021/11/20 Vue.js
MySQL事务操作的四大特性以及并发事务问题
2022/04/12 MySQL