机器学习python实战之决策树


Posted in Python onNovember 01, 2017

决策树原理:从数据集中找出决定性的特征对数据集进行迭代划分,直到某个分支下的数据都属于同一类型,或者已经遍历了所有划分数据集的特征,停止决策树算法。

每次划分数据集的特征都有很多,那么我们怎么来选择到底根据哪一个特征划分数据集呢?这里我们需要引入信息增益和信息熵的概念。

一、信息增益

划分数据集的原则是:将无序的数据变的有序。在划分数据集之前之后信息发生的变化称为信息增益。知道如何计算信息增益,我们就可以计算根据每个特征划分数据集获得的信息增益,选择信息增益最高的特征就是最好的选择。首先我们先来明确一下信息的定义:符号xi的信息定义为 l(xi)=-log2 p(xi),p(xi)为选择该类的概率。那么信息源的熵H=-∑p(xi)·log2 p(xi)。根据这个公式我们下面编写代码计算香农熵

def calcShannonEnt(dataSet):
 NumEntries = len(dataSet)
 labelsCount = {}
 for i in dataSet:
  currentlabel = i[-1]
  if currentlabel not in labelsCount.keys():
   labelsCount[currentlabel]=0
  labelsCount[currentlabel]+=1
 ShannonEnt = 0.0
 for key in labelsCount:
  prob = labelsCount[key]/NumEntries
  ShannonEnt -= prob*log(prob,2)
 return ShannonEnt

上面的自定义函数我们需要在之前导入log方法,from math import log。 我们可以先用一个简单的例子来测试一下

def createdataSet():
 #dataSet = [['1','1','yes'],['1','0','no'],['0','1','no'],['0','0','no']]
 dataSet = [[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,0,'no']]
 labels = ['no surfacing','flippers']
 return dataSet,labels

机器学习python实战之决策树

这里的熵为0.811,当我们增加数据的类别时,熵会增加。这里更改后的数据集的类别有三种‘yes'、‘no'、‘maybe',也就是说数据越混乱,熵就越大。

机器学习python实战之决策树

分类算法出了需要计算信息熵,还需要划分数据集。决策树算法中我们对根据每个特征划分的数据集计算一次熵,然后判断按照哪个特征划分是最好的划分方式。

def splitDataSet(dataSet,axis,value):
 retDataSet = []
 for featVec in dataSet:
  if featVec[axis] == value:
   reducedfeatVec = featVec[:axis]
   reducedfeatVec.extend(featVec[axis+1:])
   retDataSet.append(reducedfeatVec)
 return retDataSet

axis表示划分数据集的特征,value表示特征的返回值。这里需要注意extend方法和append方法的区别。举例来说明这个区别

机器学习python实战之决策树

下面我们测试一下划分数据集函数的结果:

机器学习python实战之决策树

axis=0,value=1,按myDat数据集的第0个特征向量是否等于1进行划分。

接下来我们将遍历整个数据集,对每个划分的数据集计算香农熵,找到最好的特征划分方式

def choosebestfeatureToSplit(dataSet):
 Numfeatures = len(dataSet)-1
 BaseShannonEnt = calcShannonEnt(dataSet)
 bestInfoGain=0.0
 bestfeature = -1
 for i in range(Numfeatures):
  featlist = [example[i] for example in dataSet]
  featSet = set(featlist)
  newEntropy = 0.0
  for value in featSet:
   subDataSet = splitDataSet(dataSet,i,value)
   prob = len(subDataSet)/len(dataSet)
   newEntropy += prob*calcShannonEnt(subDataSet) 
  infoGain = BaseShannonEnt-newEntropy
  if infoGain>bestInfoGain:
   bestInfoGain=infoGain
   bestfeature = i
 return bestfeature

信息增益是熵的减少或数据无序度的减少。最后比较所有特征中的信息增益,返回最好特征划分的索引。函数测试结果为

机器学习python实战之决策树

接下来开始递归构建决策树,我们需要在构建前计算列的数目,查看算法是否使用了所有的属性。这个函数跟跟第二章的calssify0采用同样的方法

def majorityCnt(classlist):
 ClassCount = {}
 for vote in classlist:
  if vote not in ClassCount.keys():
   ClassCount[vote]=0
  ClassCount[vote]+=1
 sortedClassCount = sorted(ClassCount.items(),key = operator.itemgetter(1),reverse = True)
 return sortedClassCount[0][0]

def createTrees(dataSet,labels):
 classList = [example[-1] for example in dataSet]
 if classList.count(classList[0]) == len(classList):
  return classList[0]
 if len(dataSet[0])==1:
  return majorityCnt(classList)
 bestfeature = choosebestfeatureToSplit(dataSet)
 bestfeatureLabel = labels[bestfeature]
 myTree = {bestfeatureLabel:{}}
 del(labels[bestfeature])
 featValue = [example[bestfeature] for example in dataSet]
 uniqueValue = set(featValue)
 for value in uniqueValue:
  subLabels = labels[:]
  myTree[bestfeatureLabel][value] = createTrees(splitDataSet(dataSet,bestfeature,value),subLabels)
 return myTree

最终决策树得到的结果如下:

机器学习python实战之决策树

有了如上的结果,我们看起来并不直观,所以我们接下来用matplotlib注解绘制树形图。matplotlib提供了一个注解工具annotations,它可以在数据图形上添加文本注释。我们先来测试一下这个注解工具的使用。

import matplotlib.pyplot as plt
decisionNode = dict(boxstyle = 'sawtooth',fc = '0.8')
leafNode = dict(boxstyle = 'sawtooth',fc = '0.8')
arrow_args = dict(arrowstyle = '<-')

def plotNode(nodeTxt,centerPt,parentPt,nodeType):
 createPlot.ax1.annotate(nodeTxt,xy = parentPt,xycoords = 'axes fraction',\
       xytext = centerPt,textcoords = 'axes fraction',\
       va = 'center',ha = 'center',bbox = nodeType,\
       arrowprops = arrow_args)
 
def createPlot():
 fig = plt.figure(1,facecolor = 'white')
 fig.clf()
 createPlot.ax1 = plt.subplot(111,frameon = False)
 plotNode('test1',(0.5,0.1),(0.1,0.5),decisionNode)
 plotNode('test2',(0.8,0.1),(0.3,0.8),leafNode)
 plt.show()

机器学习python实战之决策树

测试过这个小例子之后我们就要开始构建注解树了。虽然有xy坐标,但在如何放置树节点的时候我们会遇到一些麻烦。所以我们需要知道有多少个叶节点,树的深度有多少层。下面的两个函数就是为了得到叶节点数目和树的深度,两个函数有相同的结构,从第一个关键字开始遍历所有的子节点,使用type()函数判断子节点是否为字典类型,若为字典类型,则可以认为该子节点是一个判断节点,然后递归调用函数getNumleafs(),使得函数遍历整棵树,并返回叶子节点数。第2个函数getTreeDepth()计算遍历过程中遇到判断节点的个数。该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回,并将计算树深度的变量加一

def getNumleafs(myTree):
 numLeafs=0
 key_sorted= sorted(myTree.keys())
 firstStr = key_sorted[0]
 secondDict = myTree[firstStr]
 for key in secondDict.keys():
  if type(secondDict[key]).__name__=='dict':
   numLeafs+=getNumleafs(secondDict[key])
  else:
   numLeafs+=1
 return numLeafs

def getTreeDepth(myTree):
 maxdepth=0
 key_sorted= sorted(myTree.keys())
 firstStr = key_sorted[0]
 secondDict = myTree[firstStr]
 for key in secondDict.keys():
  if type(secondDict[key]).__name__ == 'dict':
   thedepth=1+getTreeDepth(secondDict[key])
  else:
   thedepth=1
  if thedepth>maxdepth:
   maxdepth=thedepth
 return maxdepth

测试结果如下

机器学习python实战之决策树

我们先给出最终的决策树图来验证上述结果的正确性

机器学习python实战之决策树

可以看出树的深度确实是有两层,叶节点的数目是3。接下来我们给出绘制决策树图的关键函数,结果就得到上图中决策树。

def plotMidText(cntrPt,parentPt,txtString):
 xMid = (parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
 yMid = (parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
 createPlot.ax1.text(xMid,yMid,txtString)
 
def plotTree(myTree,parentPt,nodeTxt):
 numLeafs = getNumleafs(myTree)
 depth = getTreeDepth(myTree)
 key_sorted= sorted(myTree.keys())
 firstStr = key_sorted[0]
 cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
 plotMidText(cntrPt,parentPt,nodeTxt)
 plotNode(firstStr,cntrPt,parentPt,decisionNode)
 secondDict = myTree[firstStr]
 plotTree.yOff -= 1.0/plotTree.totalD
 for key in secondDict.keys():
  if type(secondDict[key]).__name__ == 'dict':
   plotTree(secondDict[key],cntrPt,str(key))
  else:
   plotTree.xOff+=1.0/plotTree.totalW
   plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
   plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
 plotTree.yOff+=1.0/plotTree.totalD
 
def createPlot(inTree):
 fig = plt.figure(1,facecolor = 'white')
 fig.clf()
 axprops = dict(xticks = [],yticks = [])
 createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)
 plotTree.totalW = float(getNumleafs(inTree))
 plotTree.totalD = float(getTreeDepth(inTree))
 plotTree.xOff = -0.5/ plotTree.totalW; plotTree.yOff = 1.0
 plotTree(inTree,(0.5,1.0),'')
 plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
几个提升Python运行效率的方法之间的对比
Apr 03 Python
Zabbix实现微信报警功能
Oct 09 Python
用Python将动态GIF图片倒放播放的方法
Nov 02 Python
简单谈谈Python中的json与pickle
Jul 19 Python
Python实现XML文件解析的示例代码
Feb 05 Python
更改Python的pip install 默认安装依赖路径方法详解
Oct 27 Python
Python完成毫秒级抢淘宝大单功能
Jun 06 Python
python 爬取学信网登录页面的例子
Aug 13 Python
Pytorch GPU显存充足却显示out of memory的解决方式
Jan 13 Python
python 使用递归回溯完美解决八皇后的问题
Feb 26 Python
Django实现任意文件上传(最简单的方法)
Jun 03 Python
python“静态”变量、实例变量与本地变量的声明示例
Nov 13 Python
详解Python开发中如何使用Hook技巧
Nov 01 #Python
python利用标准库如何获取本地IP示例详解
Nov 01 #Python
你眼中的Python大牛 应该都有这份书单
Oct 31 #Python
Python生成数字图片代码分享
Oct 31 #Python
python使用标准库根据进程名如何获取进程的pid详解
Oct 31 #Python
Python列表删除的三种方法代码分享
Oct 31 #Python
Python文件的读写和异常代码示例
Oct 31 #Python
You might like
php 安全过滤函数代码
2011/05/07 PHP
PHP取进制余数函数代码
2012/01/19 PHP
源码分析 Laravel 重复执行同一个队列任务的原因
2017/12/25 PHP
PHP实现QQ登录的开原理和实现过程
2018/02/04 PHP
用php实现分页效果的示例代码
2020/12/10 PHP
js右键菜单效果代码
2007/07/21 Javascript
JQuery的一些小应用收集
2010/03/27 Javascript
基于Jquery的开发个代阴影的对话框效果代码
2011/07/28 Javascript
jquery 鼠标滑动显示详情应用示例
2014/01/24 Javascript
$.each与$().each的区别示例介绍
2014/03/20 Javascript
Javascript实现图片加载从模糊到清晰显示的方法
2016/06/21 Javascript
footer定位页面底部(代码分享)
2017/03/07 Javascript
Angular项目中$scope.$apply()方法的使用详解
2017/07/26 Javascript
AngularJS模糊查询功能实现代码(过滤内容下拉菜单排序过滤敏感字符验证判断后添加表格信息)
2017/10/24 Javascript
ng-events类似ionic中Events的angular全局事件
2018/09/05 Javascript
Vue2.x通用编辑组件的封装及应用详解
2019/05/28 Javascript
JavaScript代码简化技巧实例解析
2020/09/09 Javascript
[08:47]DOTA2每周TOP10 精彩击杀集锦vol.6
2014/06/25 DOTA
[01:20:38]完美世界DOTA2联赛 GXR vs IO 第一场 11.07
2020/11/09 DOTA
[01:20:47]DOTA2-DPC中国联赛 正赛 Ehome vs Magma BO3 第一场 1月19日
2021/03/11 DOTA
使用Python判断IP地址合法性的方法实例
2014/03/13 Python
python 打印出所有的对象/模块的属性(实例代码)
2016/09/11 Python
利用Python代码实现数据可视化的5种方法详解
2018/03/25 Python
对python3标准库httpclient的使用详解
2018/12/18 Python
python模拟菜刀反弹shell绕过限制【推荐】
2019/06/25 Python
Python3标准库之threading进程中管理并发操作方法
2020/03/30 Python
英国钻石公司:British Diamond Company
2020/02/16 全球购物
电厂厂长岗位职责
2014/01/02 职场文书
年终奖发放方案
2014/06/02 职场文书
民族学专业求职信
2014/07/28 职场文书
自强自立美德少年事迹材料
2014/08/16 职场文书
民政工作个人总结
2015/02/28 职场文书
电影开国大典观后感
2015/06/04 职场文书
五年级作文之劳动作文
2019/11/12 职场文书
浅谈 JavaScript 沙箱Sandbox
2021/11/02 Javascript
详解JAVA的控制语句
2021/11/11 Java/Android