softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现爬虫统计学校BBS男女比例之数据处理(三)
Dec 31 Python
总结python实现父类调用两种方法的不同
Jan 15 Python
Django实现登录随机验证码的示例代码
Jun 20 Python
Python爬取成语接龙类网站
Oct 19 Python
python中正则表达式与模式匹配
May 07 Python
python实现人工智能Ai抠图功能
Sep 05 Python
Python标准库itertools的使用方法
Jan 17 Python
python如何将两张图片生成为全景图片
Mar 05 Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 Python
关于keras中keras.layers.merge的用法说明
May 23 Python
Python实现打包成库供别的模块调用
Jul 13 Python
基于python实现银行管理系统
Apr 20 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
使用PHP下载CSS文件中的图片的代码
2013/09/24 PHP
PHP判断是否有Get参数的方法
2014/05/05 PHP
php中curl、fsocket、file_get_content三个函数的使用比较
2014/05/09 PHP
php代码架构的八点注意事项
2016/01/25 PHP
PHP解压tar.gz格式文件的方法
2016/02/14 PHP
php微信浏览器分享设置以及回调详解
2016/08/01 PHP
php is_writable判断文件是否可写实例代码
2016/10/13 PHP
解决php用mysql方式连接数据库出现Deprecated报错问题
2019/12/25 PHP
鼠标图片振动代码
2006/07/06 Javascript
Javascript 代码也可以变得优美的实现方法
2009/06/22 Javascript
js继承的实现代码
2010/08/05 Javascript
javascript计算用户打开网页的停留时间
2014/01/09 Javascript
js日期、星座的级联显示代码
2014/01/23 Javascript
node.js应用后台守护进程管理器Forever安装和使用实例
2014/06/01 Javascript
jquery根据锚点offset值实现动画切换
2014/09/11 Javascript
Javascript 中创建自定义对象的方法汇总
2014/12/04 Javascript
讲解JavaScript中for...in语句的使用方法
2015/06/03 Javascript
JavaScript String 对象常用方法详解
2016/05/13 Javascript
js计算系统当前日期是星期几的方法
2016/07/14 Javascript
js控制按钮,防止频繁点击响应的实例
2017/02/15 Javascript
node.js连接mysql与基本用法示例
2019/01/05 Javascript
js实现AI五子棋人机大战
2020/05/28 Javascript
[05:53]敌法师的金色冠名ID"BurNIng",是传说,是荣耀
2020/07/11 DOTA
python计算时间差的方法
2015/05/20 Python
python字符串和常用数据结构知识总结
2019/05/21 Python
python异步编程 使用yield from过程解析
2019/09/25 Python
Pandas-Cookbook 时间戳处理方式
2019/12/07 Python
Python图像阈值化处理及算法比对实例解析
2020/06/19 Python
CSS3 box-sizing属性
2009/04/17 HTML / CSS
CSS3实现文本垂直排列的方法
2018/07/10 HTML / CSS
英国最大的独立家具零售商:Furniture Village
2016/09/06 全球购物
澳大利亚优质的家居用品和生活方式公司:Bed Bath N’ Table
2019/04/16 全球购物
《登鹳雀楼》教学反思
2014/04/09 职场文书
医院财务人员岗位职责
2015/04/14 职场文书
解决redis sentinel 频繁主备切换的问题
2021/04/12 Redis
LeetCode189轮转数组python示例
2022/08/05 Python