Python机器学习之决策树算法实例详解


Posted in Python onDecember 06, 2017

本文实例讲述了Python机器学习之决策树算法。分享给大家供大家参考,具体如下:

决策树学习是应用最广泛的归纳推理算法之一,是一种逼近离散值目标函数的方法,在这种方法中学习到的函数被表示为一棵决策树。决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,机器学习算法最终将使用这些从数据集中创造的规则。决策树的优点为:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。缺点为:可能产生过度匹配的问题。决策树适于处理离散型和连续型的数据。

在决策树中最重要的就是如何选取用于划分的特征

在算法中一般选用ID3,D3算法的核心问题是选取在树的每个节点要测试的特征或者属性,希望选择的是最有助于分类实例的属性。如何定量地衡量一个属性的价值呢?这里需要引入熵和信息增益的概念。熵是信息论中广泛使用的一个度量标准,刻画了任意样本集的纯度。

假设有10个训练样本,其中6个的分类标签为yes,4个的分类标签为no,那熵是多少呢?在该例子中,分类的数目为2(yes,no),yes的概率为0.6,no的概率为0.4,则熵为 :

Python机器学习之决策树算法实例详解

Python机器学习之决策树算法实例详解

其中value(A)是属性A所有可能值的集合,Python机器学习之决策树算法实例详解是S中属性A的值为v的子集,即Python机器学习之决策树算法实例详解。上述公式的第一项为原集合S的熵,第二项是用A分类S后熵的期望值,该项描述的期望熵就是每个子集的熵的加权和,权值为属于的样本占原始样本S的比例Python机器学习之决策树算法实例详解。所以Gain(S, A)是由于知道属性A的值而导致的期望熵减少。

完整的代码:

# -*- coding: cp936 -*-
from numpy import *
import operator
from math import log
import operator
def createDataSet():
  dataSet = [[1,1,'yes'],
    [1,1,'yes'],
    [1,0,'no'],
    [0,1,'no'],
    [0,1,'no']]
  labels = ['no surfacing','flippers']
  return dataSet, labels
def calcShannonEnt(dataSet):
  numEntries = len(dataSet)
  labelCounts = {} # a dictionary for feature
  for featVec in dataSet:
    currentLabel = featVec[-1]
    if currentLabel not in labelCounts.keys():
      labelCounts[currentLabel] = 0
    labelCounts[currentLabel] += 1
  shannonEnt = 0.0
  for key in labelCounts:
    #print(key)
    #print(labelCounts[key])
    prob = float(labelCounts[key])/numEntries
    #print(prob)
    shannonEnt -= prob * log(prob,2)
  return shannonEnt
#按照给定的特征划分数据集
#根据axis等于value的特征将数据提出
def splitDataSet(dataSet, axis, value):
  retDataSet = []
  for featVec in dataSet:
    if featVec[axis] == value:
      reducedFeatVec = featVec[:axis]
      reducedFeatVec.extend(featVec[axis+1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet
#选取特征,划分数据集,计算得出最好的划分数据集的特征
def chooseBestFeatureToSplit(dataSet):
  numFeatures = len(dataSet[0]) - 1 #剩下的是特征的个数
  baseEntropy = calcShannonEnt(dataSet)#计算数据集的熵,放到baseEntropy中
  bestInfoGain = 0.0;bestFeature = -1 #初始化熵增益
  for i in range(numFeatures):
    featList = [example[i] for example in dataSet] #featList存储对应特征所有可能得取值
    uniqueVals = set(featList)
    newEntropy = 0.0
    for value in uniqueVals:#下面是计算每种划分方式的信息熵,特征i个,每个特征value个值
      subDataSet = splitDataSet(dataSet, i ,value)
      prob = len(subDataSet)/float(len(dataSet)) #特征样本在总样本中的权重
      newEntropy = prob * calcShannonEnt(subDataSet)
    infoGain = baseEntropy - newEntropy #计算i个特征的信息熵
    #print(i)
    #print(infoGain)
    if(infoGain > bestInfoGain):
      bestInfoGain = infoGain
      bestFeature = i
  return bestFeature
#如上面是决策树所有的功能模块
#得到原始数据集之后基于最好的属性值进行划分,每一次划分之后传递到树分支的下一个节点
#递归结束的条件是程序遍历完成所有的数据集属性,或者是每一个分支下的所有实例都具有相同的分类
#如果所有实例具有相同的分类,则得到一个叶子节点或者终止快
#如果所有属性都已经被处理,但是类标签依然不是确定的,那么采用多数投票的方式
#返回出现次数最多的分类名称
def majorityCnt(classList):
  classCount = {}
  for vote in classList:
    if vote not in classCount.keys():classCount[vote] = 0
    classCount[vote] += 1
  sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True)
  return sortedClassCount[0][0]
#创建决策树
def createTree(dataSet,labels):
  classList = [example[-1] for example in dataSet]#将最后一行的数据放到classList中,所有的类别的值
  if classList.count(classList[0]) == len(classList): #类别完全相同不需要再划分
    return classList[0]
  if len(dataSet[0]) == 1:#这里为什么是1呢?就是说特征数为1的时候
    return majorityCnt(classList)#就返回这个特征就行了,因为就这一个特征
  bestFeat = chooseBestFeatureToSplit(dataSet)
  print('the bestFeatue in creating is :')
  print(bestFeat)
  bestFeatLabel = labels[bestFeat]#运行结果'no surfacing'
  myTree = {bestFeatLabel:{}}#嵌套字典,目前value是一个空字典
  del(labels[bestFeat])
  featValues = [example[bestFeat] for example in dataSet]#第0个特征对应的取值
  uniqueVals = set(featValues)
  for value in uniqueVals: #根据当前特征值的取值进行下一级的划分
    subLabels = labels[:]
    myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
  return myTree
#对上面简单的数据进行小测试
def testTree1():
  myDat,labels=createDataSet()
  val = calcShannonEnt(myDat)
  print 'The classify accuracy is: %.2f%%' % val
  retDataSet1 = splitDataSet(myDat,0,1)
  print (myDat)
  print(retDataSet1)
  retDataSet0 = splitDataSet(myDat,0,0)
  print (myDat)
  print(retDataSet0)
  bestfeature = chooseBestFeatureToSplit(myDat)
  print('the bestFeatue is :')
  print(bestfeature)
  tree = createTree(myDat,labels)
  print(tree)

对应的结果是:

>>> import TREE
>>> TREE.testTree1()
The classify accuracy is: 0.97%
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'yes'], [1, 'yes'], [0, 'no']]
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'no'], [1, 'no']]
the bestFeatue is :
0
the bestFeatue in creating is :
0
the bestFeatue in creating is :
0
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

最好再增加使用决策树的分类函数

同时因为构建决策树是非常耗时间的,因为最好是将构建好的树通过 python 的 pickle 序列化对象,将对象保存在磁盘上,等到需要用的时候再读出

def classify(inputTree,featLabels,testVec):
  firstStr = inputTree.keys()[0]
  secondDict = inputTree[firstStr]
  featIndex = featLabels.index(firstStr)
  key = testVec[featIndex]
  valueOfFeat = secondDict[key]
  if isinstance(valueOfFeat, dict):
    classLabel = classify(valueOfFeat, featLabels, testVec)
  else: classLabel = valueOfFeat
  return classLabel
def storeTree(inputTree,filename):
  import pickle
  fw = open(filename,'w')
  pickle.dump(inputTree,fw)
  fw.close()
def grabTree(filename):
  import pickle
  fr = open(filename)
  return pickle.load(fr)

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python 过滤字符串的技巧,map与itertools.imap
Sep 06 Python
python实现通过shelve修改对象实例
Sep 26 Python
使用Python编写一个在Linux下实现截图分享的脚本的教程
Apr 24 Python
Python基于pygame实现图片代替鼠标移动效果
Nov 11 Python
Windows下安装python MySQLdb遇到的问题及解决方法
Mar 16 Python
vscode 远程调试python的方法
Dec 01 Python
对Pyhon实现静态变量全局变量的方法详解
Jan 11 Python
PyQt5重写QComboBox的鼠标点击事件方法
Jun 25 Python
Python整数与Numpy数据溢出问题解决
Sep 11 Python
python修改FTP服务器上的文件名
Sep 11 Python
Python读写文件模式和文件对象方法实例详解
Sep 17 Python
python可视化分析绘制带趋势线的散点图和边缘直方图
Jun 25 Python
快速入门python学习笔记
Dec 06 #Python
Python中django学习心得
Dec 06 #Python
Python标准库inspect的具体使用方法
Dec 06 #Python
读取本地json文件,解析json(实例讲解)
Dec 06 #Python
Python语言描述最大连续子序列和
Dec 05 #Python
python matplotlib坐标轴设置的方法
Dec 05 #Python
详解K-means算法在Python中的实现
Dec 05 #Python
You might like
将PHP作为Shell脚本语言使用
2006/10/09 PHP
php db类库进行数据库操作
2009/03/19 PHP
PHP分页函数代码(简单实用型)
2010/12/02 PHP
PHP递归算法的详细示例分析
2013/02/19 PHP
PHP Cookie学习笔记
2016/08/23 PHP
php读取和保存base64编码的图片内容
2017/04/22 PHP
ThinkPHP5.1框架页面跳转及修改跳转页面模版示例
2019/05/06 PHP
JS input文本框禁用右键和复制粘贴功能的代码
2010/04/15 Javascript
鼠标滚轮改变图片大小的示例代码
2013/11/20 Javascript
javascript实现左右控制无缝滚动
2014/12/31 Javascript
使用AOP改善javascript代码
2015/05/01 Javascript
深入分析jsonp协议原理
2015/09/26 Javascript
JavaScript 是什么意思
2016/09/22 Javascript
nodeJs链接Mysql做增删改查的简单操作
2017/02/04 NodeJs
javascript中BOM基础知识总结
2017/02/14 Javascript
详解AngularJS用Interceptors来统一处理HTTP请求和响应
2017/06/08 Javascript
JS实现获取自定义属性data值的方法示例
2018/12/19 Javascript
layer弹出层扩展主题的方法
2019/09/11 Javascript
vscode自定义vue模板的实现
2021/01/27 Vue.js
在Python中用keys()方法返回字典键的教程
2015/05/21 Python
教你用 Python 实现微信跳一跳(Mac+iOS版)
2018/01/04 Python
Python实现可获取网易页面所有文本信息的网易网络爬虫功能示例
2018/01/15 Python
Python爬虫框架Scrapy实例代码
2018/03/04 Python
pytorch 实现将自己的图片数据处理成可以训练的图片类型
2020/01/08 Python
python3跳出一个循环的实例操作
2020/08/18 Python
python 元组和列表的区别
2020/12/30 Python
基于MUI框架使用HTML5实现的二维码扫描功能
2018/03/01 HTML / CSS
美丽的现代设计家具:2Modern
2018/07/26 全球购物
加拿大品牌鞋包连锁店:Little Burgundy
2021/02/28 全球购物
凌阳科技股份有限公司C++程序员面试题笔试题
2014/11/20 面试题
幼儿园评语大全
2014/04/17 职场文书
巾帼文明岗申报材料
2014/05/01 职场文书
工作汇报开头与结尾怎么写
2014/11/08 职场文书
2015年度培训工作总结范文
2015/04/02 职场文书
小学二年级班主任工作经验交流材料
2015/11/02 职场文书
php随机生成验证码,php随机生成数字,php随机生成数字加字母!
2021/04/01 PHP