基于ID3决策树算法的实现(Python版)


Posted in Python onMay 31, 2017

实例如下:

# -*- coding:utf-8 -*-

from numpy import *
import numpy as np
import pandas as pd
from math import log
import operator

#计算数据集的香农熵
def calcShannonEnt(dataSet):
  numEntries=len(dataSet)
  labelCounts={}
  #给所有可能分类创建字典
  for featVec in dataSet:
    currentLabel=featVec[-1]
    if currentLabel not in labelCounts.keys():
      labelCounts[currentLabel]=0
    labelCounts[currentLabel]+=1
  shannonEnt=0.0
  #以2为底数计算香农熵
  for key in labelCounts:
    prob = float(labelCounts[key])/numEntries
    shannonEnt-=prob*log(prob,2)
  return shannonEnt


#对离散变量划分数据集,取出该特征取值为value的所有样本
def splitDataSet(dataSet,axis,value):
  retDataSet=[]
  for featVec in dataSet:
    if featVec[axis]==value:
      reducedFeatVec=featVec[:axis]
      reducedFeatVec.extend(featVec[axis+1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet

#对连续变量划分数据集,direction规定划分的方向,
#决定是划分出小于value的数据样本还是大于value的数据样本集
def splitContinuousDataSet(dataSet,axis,value,direction):
  retDataSet=[]
  for featVec in dataSet:
    if direction==0:
      if featVec[axis]>value:
        reducedFeatVec=featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
    else:
      if featVec[axis]<=value:
        reducedFeatVec=featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
  return retDataSet

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet,labels):
  numFeatures=len(dataSet[0])-1
  baseEntropy=calcShannonEnt(dataSet)
  bestInfoGain=0.0
  bestFeature=-1
  bestSplitDict={}
  for i in range(numFeatures):
    featList=[example[i] for example in dataSet]
    #对连续型特征进行处理
    if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
      #产生n-1个候选划分点
      sortfeatList=sorted(featList)
      splitList=[]
      for j in range(len(sortfeatList)-1):
        splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)

      bestSplitEntropy=10000
      slen=len(splitList)
      #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
      for j in range(slen):
        value=splitList[j]
        newEntropy=0.0
        subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
        subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
        prob0=len(subDataSet0)/float(len(dataSet))
        newEntropy+=prob0*calcShannonEnt(subDataSet0)
        prob1=len(subDataSet1)/float(len(dataSet))
        newEntropy+=prob1*calcShannonEnt(subDataSet1)
        if newEntropy<bestSplitEntropy:
          bestSplitEntropy=newEntropy
          bestSplit=j
      #用字典记录当前特征的最佳划分点
      bestSplitDict[labels[i]]=splitList[bestSplit]
      infoGain=baseEntropy-bestSplitEntropy
    #对离散型特征进行处理
    else:
      uniqueVals=set(featList)
      newEntropy=0.0
      #计算该特征下每种划分的信息熵
      for value in uniqueVals:
        subDataSet=splitDataSet(dataSet,i,value)
        prob=len(subDataSet)/float(len(dataSet))
        newEntropy+=prob*calcShannonEnt(subDataSet)
      infoGain=baseEntropy-newEntropy
    if infoGain>bestInfoGain:
      bestInfoGain=infoGain
      bestFeature=i
  #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
  #即是否小于等于bestSplitValue
  if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':
    bestSplitValue=bestSplitDict[labels[bestFeature]]
    labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
    for i in range(shape(dataSet)[0]):
      if dataSet[i][bestFeature]<=bestSplitValue:
        dataSet[i][bestFeature]=1
      else:
        dataSet[i][bestFeature]=0
  return bestFeature

#特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票
def majorityCnt(classList):
  classCount={}
  for vote in classList:
    if vote not in classCount.keys():
      classCount[vote]=0
    classCount[vote]+=1
  return max(classCount)

#主程序,递归产生决策树
def createTree(dataSet,labels,data_full,labels_full):
  classList=[example[-1] for example in dataSet]
  if classList.count(classList[0])==len(classList):
    return classList[0]
  if len(dataSet[0])==1:
    return majorityCnt(classList)
  bestFeat=chooseBestFeatureToSplit(dataSet,labels)
  bestFeatLabel=labels[bestFeat]
  myTree={bestFeatLabel:{}}
  featValues=[example[bestFeat] for example in dataSet]
  uniqueVals=set(featValues)
  if type(dataSet[0][bestFeat]).__name__=='str':
    currentlabel=labels_full.index(labels[bestFeat])
    featValuesFull=[example[currentlabel] for example in data_full]
    uniqueValsFull=set(featValuesFull)
  del(labels[bestFeat])
  #针对bestFeat的每个取值,划分出一个子树。
  for value in uniqueVals:
    subLabels=labels[:]
    if type(dataSet[0][bestFeat]).__name__=='str':
      uniqueValsFull.remove(value)
    myTree[bestFeatLabel][value]=createTree(splitDataSet\
     (dataSet,bestFeat,value),subLabels,data_full,labels_full)
  if type(dataSet[0][bestFeat]).__name__=='str':
    for value in uniqueValsFull:
      myTree[bestFeatLabel][value]=majorityCnt(classList)
  return myTree

import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")


#计算树的叶子节点数量
def getNumLeafs(myTree):
  numLeafs=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      numLeafs+=getNumLeafs(secondDict[key])
    else: numLeafs+=1
  return numLeafs

#计算树的最大深度
def getTreeDepth(myTree):
  maxDepth=0
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  secondDict=myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      thisDepth=1+getTreeDepth(secondDict[key])
    else: thisDepth=1
    if thisDepth>maxDepth:
      maxDepth=thisDepth
  return maxDepth

#画节点
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
  createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
  xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
  bbox=nodeType,arrowprops=arrow_args)

#画箭头上的文字
def plotMidText(cntrPt,parentPt,txtString):
  lens=len(txtString)
  xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
  yMid=(parentPt[1]+cntrPt[1])/2.0
  createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
  numLeafs=getNumLeafs(myTree)
  depth=getTreeDepth(myTree)
  firstSides = list(myTree.keys())
  firstStr=firstSides[0]
  cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
  plotMidText(cntrPt,parentPt,nodeTxt)
  plotNode(firstStr,cntrPt,parentPt,decisionNode)
  secondDict=myTree[firstStr]
  plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
  for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
      plotTree(secondDict[key],cntrPt,str(key))
    else:
      plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
      plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
      plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
  plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD

def createPlot(inTree):
  fig=plt.figure(1,facecolor='white')
  fig.clf()
  axprops=dict(xticks=[],yticks=[])
  createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
  plotTree.totalW=float(getNumLeafs(inTree))
  plotTree.totalD=float(getTreeDepth(inTree))
  plotTree.x0ff=-0.5/plotTree.totalW
  plotTree.y0ff=1.0
  plotTree(inTree,(0.5,1.0),'')
  plt.show()

df=pd.read_csv('watermelon_4_3.csv')
data=df.values[:,1:].tolist()
data_full=data[:]
labels=df.columns.values[1:-1].tolist()
labels_full=labels[:]
myTree=createTree(data,labels,data_full,labels_full)
print(myTree)
createPlot(myTree)

最终结果如下:

{'texture': {'blur': 0, 'little_blur': {'touch': {'soft_stick': 1, 'hard_smooth': 0}}, 'distinct': {'density<=0.38149999999999995': {0: 1, 1: 0}}}}

得到的决策树如下:

基于ID3决策树算法的实现(Python版)

参考资料:

《机器学习实战》

《机器学习》周志华著

以上这篇基于ID3决策树算法的实现(Python版)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 将RGB图像转换为Pytho灰度图像的实例
Nov 14 Python
详谈pandas中agg函数和apply函数的区别
Apr 20 Python
Python 字符串与数字输出方法
Jul 16 Python
Python面向对象思想与应用入门教程【类与对象】
Apr 12 Python
学习python分支结构
May 17 Python
简单了解python的内存管理机制
Jul 08 Python
如何使用python操作vmware
Jul 27 Python
Python unittest框架操作实例解析
Apr 13 Python
jupyter notebook tensorflow打印device信息实例
Apr 20 Python
Python可以实现栈的结构吗
May 27 Python
python基础之爬虫入门
May 10 Python
Python中可变和不可变对象的深入讲解
Aug 02 Python
Python基础知识_浅谈用户交互
May 31 #Python
python数据类型_字符串常用操作(详解)
May 30 #Python
python数据类型_元组、字典常用操作方法(介绍)
May 30 #Python
node.js获取参数的常用方法(总结)
May 29 #Python
老生常谈python函数参数的区别(必看篇)
May 29 #Python
Python进阶_关于命名空间与作用域(详解)
May 29 #Python
浅谈对yield的初步理解
May 29 #Python
You might like
php 全文搜索和替换的实现代码
2008/07/29 PHP
php实现的简单检验登陆类
2015/06/18 PHP
php分割合并两个字符串的函数实例
2015/06/19 PHP
在IE,Firefox,Safari,Chrome,Opera浏览器上调试javascript
2008/12/02 Javascript
分页栏的web标准实现
2011/11/01 Javascript
改变隐藏的input中value的值代码
2013/12/30 Javascript
JavaScript fontcolor方法入门实例(按照指定的颜色来显示字符串)
2014/10/17 Javascript
jQuery+css3实现转动的正方形效果(附demo源码下载)
2016/01/27 Javascript
JavaScript的Backbone.js框架环境搭建及Hellow world示例
2016/05/07 Javascript
jQuery Ajax 上传文件处理方式介绍(推荐)
2016/06/30 Javascript
jQuery+ajax的资源回收处理机制分析
2017/01/07 Javascript
jquery获取select,option所有的value和text的实例
2017/03/06 Javascript
基于angular实现模拟微信小程序swiper组件
2017/06/11 Javascript
JavaScript严格模式下关于this的几种指向详解
2017/07/12 Javascript
自制简易打赏功能的实例
2017/09/02 Javascript
深入解析Vue源码实例挂载与编译流程实现思路详解
2019/05/05 Javascript
vue通过数据过滤实现表格合并
2020/11/30 Javascript
vue实现将一个数组内的相同数据进行合并
2019/11/07 Javascript
Python3连接SQLServer、Oracle、MySql的方法
2018/06/28 Python
pyenv与virtualenv安装实现python多版本多项目管理
2019/08/17 Python
Python实现隐马尔可夫模型的前向后向算法的示例代码
2019/12/31 Python
导入tensorflow:ImportError: libcublas.so.9.0 报错
2020/01/06 Python
新手学python应该下哪个版本
2020/06/11 Python
canvas基础之图形验证码的示例
2018/01/02 HTML / CSS
英国大码女性时装零售商:Evans
2018/08/29 全球购物
如何让Java程序执行效率更高
2014/06/25 面试题
怎样有效的进行自我评价
2013/10/06 职场文书
求职信格式范本
2013/11/15 职场文书
医学检验专业个人求职信范文
2013/12/04 职场文书
预备党员公开承诺书
2014/05/28 职场文书
2014幼儿教师个人工作总结
2014/12/03 职场文书
岳庙导游词
2015/02/04 职场文书
施工员岗位职责
2015/02/10 职场文书
担保贷款承诺书
2015/04/30 职场文书
【超详细】八大排序算法的各项比较以及各自特点
2021/03/31 Python
QT连接MYSQL数据库的详细步骤
2021/07/07 MySQL