A Detailed Guide to Implementing Decision Tree Pruning Algorithms in Python


Posted in Python on September 18, 2019

This article walks through a Python implementation of decision tree pruning algorithms, shared here for your reference. The details are as follows:

A decision tree is a tree built around a sequence of decisions. In machine learning it is a predictive model that encodes a mapping from object attributes to object values: each internal node tests an attribute, each branch out of a node corresponds to a possible value of that attribute, and each leaf holds the value predicted for the objects described by the path from the root to that leaf. A decision tree has a single output; to predict several outputs, build an independent tree for each one.
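Throughout this article a tree is stored as nested Python dictionaries: a decision node is a dict whose single key names the feature tested (for a binarized continuous feature, a string such as 'density<=0.38'), and whose values are either subtrees or class labels. A minimal hand-built sketch, with made-up feature names and classes:

# A hand-built tree in the nested-dict format used throughout this article.
# The feature names ("texture", "density") and classes ("good", "bad") are
# illustrative only.
myTree = {
  "texture": {                 # root node tests the "texture" feature
    "clear": {                 # one branch per observed feature value
      "density<=0.38": {       # a binarized continuous feature
        1: "bad",              # 1 means density <= 0.38
        0: "good",             # 0 means density >  0.38
      }
    },
    "blurry": "bad",           # a leaf: a class label instead of a dict
  }
}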

ID3: ID3 (Iterative Dichotomiser 3), invented by Ross Quinlan, is a decision tree algorithm grounded in the principle of Occam's razor: do as much as possible with as little as possible, so a smaller decision tree is preferred over a larger one. Even so, ID3 does not always produce the smallest possible tree; it is a heuristic. In information-theoretic terms, the smaller the expected information remaining after a split, the larger the information gain and the purer the resulting subsets. The core idea of ID3 is therefore to use information gain as the attribute-selection criterion and, at each node, split on the attribute with the largest information gain, performing a top-down greedy search through the space of possible trees.
Information entropy measures the uncertainty of a discrete random variable: $H(T) = -\sum_j p_j \log_2 p_j$, where $p_j$ is the probability of class $j$. The more ordered a system is, the lower its entropy; the more chaotic, the higher. For example, a node containing two classes in equal proportion has entropy $-0.5\log_2 0.5 - 0.5\log_2 0.5 = 1$ bit, while a pure node has entropy 0, so entropy can be read as a measure of how ordered a system is.

Gini index: CART chooses splits using the Gini index, defined as $gini(T) = 1 - \sum_{j=1}^{n} p_j^2$, where $p_j$ is the relative frequency of class $j$ in $T$; $gini(T)$ is smallest when the class distribution in $T$ is skewed. After partitioning $T$ into two subsets $T_1$ (with $N_1$ instances) and $T_2$ (with $N_2$ instances), the Gini index of the split is $gini_{split}(T) = \frac{N_1}{N}\,gini(T_1) + \frac{N_2}{N}\,gini(T_2)$, and the split with the smallest $gini_{split}(T)$ is chosen for the node.
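The implementation below actually selects splits by information gain rather than the Gini index, but the two formulas above translate directly into code. A minimal sketch (giniIndex and giniSplit are illustrative names, not part of the article's code):

def giniIndex(dataSet):
  # gini(T) = 1 - sum_j p_j^2, computed over the class column (the last one)
  total = float(len(dataSet))
  counts = {}
  for featVec in dataSet:
    counts[featVec[-1]] = counts.get(featVec[-1], 0) + 1
  return 1.0 - sum((n / total) ** 2 for n in counts.values())

def giniSplit(subSet1, subSet2):
  # gini_split(T) = N1/N * gini(T1) + N2/N * gini(T2)
  n1, n2 = len(subSet1), len(subSet2)
  total = float(n1 + n2)
  return n1 / total * giniIndex(subSet1) + n2 / total * giniIndex(subSet2)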
**Implementation**

First, the function calcShannonEnt computes the Shannon entropy of the dataset, building a dictionary of counts for every possible class:

from math import log

def calcShannonEnt(dataSet):
  numEntries = len(dataSet)
  labelCounts = {}
  # build a dictionary of counts for every possible class
  for featVec in dataSet:
    currentLabel = featVec[-1]
    if currentLabel not in labelCounts:
      labelCounts[currentLabel] = 0
    labelCounts[currentLabel] += 1
  shannonEnt = 0.0
  # accumulate the Shannon entropy using base-2 logarithms
  for key in labelCounts:
    prob = float(labelCounts[key]) / numEntries
    shannonEnt -= prob * log(prob, 2)
  return shannonEnt

# Split the dataset on a discrete feature: keep every sample whose feature
# at index axis equals value, with that feature column removed.
def splitDataSet(dataSet, axis, value):
  retDataSet = []
  for featVec in dataSet:
    if featVec[axis] == value:
      reducedFeatVec = featVec[:axis]
      reducedFeatVec.extend(featVec[axis + 1:])
      retDataSet.append(reducedFeatVec)
  return retDataSet
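A quick sanity check on a made-up toy dataset, where each row lists the feature values followed by the class label:

# Each row: [feature0, feature1, class]; the class column comes last.
toyData = [["clear", 0.46, "good"],
           ["clear", 0.77, "good"],
           ["blurry", 0.24, "bad"],
           ["blurry", 0.36, "bad"]]

print(calcShannonEnt(toyData))            # 1.0: two classes, 50/50 split
print(splitDataSet(toyData, 0, "clear"))  # [[0.46, 'good'], [0.77, 'good']]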

For continuous features, the dataset is split by splitContinuousDataSet, whose direction argument determines which side of the split is returned: the samples less than or equal to value, or the samples greater than value.
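The body of splitContinuousDataSet is missing from the original listing, so the sketch below is reconstructed from how chooseBestFeatureToSplit calls it; it assumes direction 0 returns the samples greater than value and direction 1 the samples less than or equal to it (for the entropy computation only the complementarity of the two halves matters):

def splitContinuousDataSet(dataSet, axis, value, direction):
  retDataSet = []
  for featVec in dataSet:
    if direction == 0:
      # keep samples strictly greater than the threshold
      if featVec[axis] > value:
        reducedFeatVec = featVec[:axis]
        reducedFeatVec.extend(featVec[axis + 1:])
        retDataSet.append(reducedFeatVec)
    else:
      # keep samples less than or equal to the threshold
      if featVec[axis] <= value:
        reducedFeatVec = featVec[:axis]
        reducedFeatVec.extend(featVec[axis + 1:])
        retDataSet.append(reducedFeatVec)
  return retDataSet

With both split helpers in place, chooseBestFeatureToSplit picks the feature, and for a continuous feature also the threshold, with the largest information gain: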

def chooseBestFeatureToSplit(dataSet, labels):
  numFeatures = len(dataSet[0]) - 1
  baseEntropy = calcShannonEnt(dataSet)
  bestInfoGain = 0.0
  bestFeature = -1
  bestSplitDict = {}
  for i in range(numFeatures):
    featList = [example[i] for example in dataSet]
    # handle continuous features
    if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int':
      # generate the n-1 candidate split points (midpoints of sorted values)
      sortfeatList = sorted(featList)
      splitList = []
      for j in range(len(sortfeatList) - 1):
        splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)

      bestSplitEntropy = 10000  # sentinel larger than any real entropy
      slen = len(splitList)
      # compute the entropy of splitting at the j-th candidate point
      # and remember the best split point
      for j in range(slen):
        value = splitList[j]
        newEntropy = 0.0
        subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
        subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
        prob0 = len(subDataSet0) / float(len(dataSet))
        newEntropy += prob0 * calcShannonEnt(subDataSet0)
        prob1 = len(subDataSet1) / float(len(dataSet))
        newEntropy += prob1 * calcShannonEnt(subDataSet1)
        if newEntropy < bestSplitEntropy:
          bestSplitEntropy = newEntropy
          bestSplit = j
      # record this feature's best split point in a dictionary
      bestSplitDict[labels[i]] = splitList[bestSplit]
      infoGain = baseEntropy - bestSplitEntropy
    # handle discrete features
    else:
      uniqueVals = set(featList)
      newEntropy = 0.0
      # weighted entropy over each value-induced partition of this feature
      for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, i, value)
        prob = len(subDataSet) / float(len(dataSet))
        newEntropy += prob * calcShannonEnt(subDataSet)
      infoGain = baseEntropy - newEntropy
    if infoGain > bestInfoGain:
      bestInfoGain = infoGain
      bestFeature = i
  # If the best feature is continuous, binarize it at the recorded split
  # point: a value becomes 1 if it is <= bestSplitValue and 0 otherwise.
  if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int':
    bestSplitValue = bestSplitDict[labels[bestFeature]]
    labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
    for i in range(len(dataSet)):
      if dataSet[i][bestFeature] <= bestSplitValue:
        dataSet[i][bestFeature] = 1
      else:
        dataSet[i][bestFeature] = 0
  return bestFeature
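A quick check on the same toy data (note that the function mutates dataSet and labels in place when it binarizes a continuous feature, so pass copies if you need the originals):

labels = ["texture", "density"]
data = [["clear", 0.46, "good"],
        ["clear", 0.77, "good"],
        ["blurry", 0.24, "bad"],
        ["blurry", 0.36, "bad"]]
print(chooseBestFeatureToSplit(data, labels))
# -> 0: "texture" already yields pure subsets, and ties favor the first feature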
The function classify walks a finished tree to predict the class of a test vector; a binarized continuous feature is recognized by the '<=' embedded in its label:

def classify(inputTree, featLabels, testVec):
  firstStr = list(inputTree.keys())[0]
  if '<=' in firstStr:
    # continuous feature: recover the threshold and the feature name
    featvalue = float(firstStr.split("<=")[1])
    featkey = firstStr.split("<=")[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(featkey)
    if testVec[featIndex] <= featvalue:
      judge = 1
    else:
      judge = 0
    for key in secondDict.keys():
      if judge == int(key):
        if type(secondDict[key]).__name__ == 'dict':
          classLabel = classify(secondDict[key], featLabels, testVec)
        else:
          classLabel = secondDict[key]
  else:
    # discrete feature: follow the branch matching the test value
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
      if testVec[featIndex] == key:
        if type(secondDict[key]).__name__ == 'dict':
          classLabel = classify(secondDict[key], featLabels, testVec)
        else:
          classLabel = secondDict[key]
  return classLabel
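Applied to the hand-built myTree from the beginning of the article:

featLabels = ["texture", "density"]
print(classify(myTree, featLabels, ["clear", 0.30]))   # 0.30 <= 0.38 -> 'bad'
print(classify(myTree, featLabels, ["clear", 0.70]))   # 0.70 >  0.38 -> 'good'
print(classify(myTree, featLabels, ["blurry", 0.50]))  # -> 'bad'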
def majorityCnt(classList):
  classCount = {}
  for vote in classList:
    if vote not in classCount:
      classCount[vote] = 0
    classCount[vote] += 1
  # return the class with the most votes; a bare max(classCount) would
  # compare the class names themselves rather than their counts
  return max(classCount, key=classCount.get)
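majorityCnt serves both as the fallback prediction and as the pruning baseline; for example:

print(majorityCnt(["good", "bad", "bad"]))  # -> 'bad' (two votes against one)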
For pre-pruning, testing_feat measures the test-set error that results from splitting on feature feat, with each branch predicting the majority class of its training samples:

def testing_feat(feat, train_data, test_data, labels):
  class_list = [example[-1] for example in train_data]
  bestFeatIndex = labels.index(feat)
  train_data = [example[bestFeatIndex] for example in train_data]
  test_data = [(example[bestFeatIndex], example[-1]) for example in test_data]
  all_feat = set(train_data)
  error = 0.0
  for value in all_feat:
    # majority class among the training samples that take this feature value
    class_feat = [class_list[i] for i in range(len(class_list)) if train_data[i] == value]
    major = majorityCnt(class_feat)
    # count test samples with this value that the majority class misclassifies
    for data in test_data:
      if data[0] == value and data[1] != major:
        error += 1.0
  return error

**Testing**

The function testing counts how many test samples the finished tree misclassifies, and testingMajor does the same for a constant majority-class prediction:

def testing(myTree, data_test, labels):
  error = 0.0
  for i in range(len(data_test)):
    if classify(myTree, labels, data_test[i]) != data_test[i][-1]:
      error += 1
  return float(error)

def testingMajor(major, data_test):
  error = 0.0
  for i in range(len(data_test)):
    if major != data_test[i][-1]:
      error += 1
  return float(error)
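These two error counts drive both pruning decisions: a split is adopted, or a grown subtree kept, only if its test error beats that of predicting the majority class. For instance, with the toy tree and feature labels from earlier:

test = [["clear", 0.30, "bad"], ["blurry", 0.50, "bad"], ["clear", 0.70, "good"]]
print(testing(myTree, test, ["texture", "density"]))            # -> 0.0
print(testingMajor(majorityCnt(["bad", "bad", "good"]), test))  # -> 1.0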

**Recursively building the decision tree**

import copy

def createTree(dataSet, labels, data_full, labels_full, test_data, mode):
  classList = [example[-1] for example in dataSet]
  # stop if the node is pure
  if classList.count(classList[0]) == len(classList):
    return classList[0]
  # stop if no features remain; predict the majority class
  if len(dataSet[0]) == 1:
    return majorityCnt(classList)
  labels_copy = copy.deepcopy(labels)
  bestFeat = chooseBestFeatureToSplit(dataSet, labels)
  bestFeatLabel = labels[bestFeat]

  # "unpro" grows the full tree; "post" grows it now and prunes afterwards
  if mode == "unpro" or mode == "post":
    myTree = {bestFeatLabel: {}}
  elif mode == "prev":
    # pre-pruning: split only if it beats predicting the majority class
    if testing_feat(bestFeatLabel, dataSet, test_data, labels_copy) < testingMajor(majorityCnt(classList), test_data):
      myTree = {bestFeatLabel: {}}
    else:
      return majorityCnt(classList)
  featValues = [example[bestFeat] for example in dataSet]
  uniqueVals = set(featValues)

  # for discrete (string) features, also collect the values seen in the
  # full dataset, so branches absent from this subset can be filled in
  if type(dataSet[0][bestFeat]).__name__ == 'str':
    currentlabel = labels_full.index(labels[bestFeat])
    featValuesFull = [example[currentlabel] for example in data_full]
    uniqueValsFull = set(featValuesFull)

  del labels[bestFeat]

  for value in uniqueVals:
    subLabels = labels[:]
    if type(dataSet[0][bestFeat]).__name__ == 'str':
      uniqueValsFull.remove(value)

    myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),
                                              subLabels, data_full, labels_full,
                                              splitDataSet(test_data, bestFeat, value),
                                              mode=mode)
  # feature values never seen at this node predict the majority class
  if type(dataSet[0][bestFeat]).__name__ == 'str':
    for value in uniqueValsFull:
      myTree[bestFeatLabel][value] = majorityCnt(classList)

  if mode == "post":
    # post-pruning: collapse this subtree if the majority class does no worse
    if testing(myTree, test_data, labels_copy) > testingMajor(majorityCnt(classList), test_data):
      return majorityCnt(classList)
  return myTree
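The mode argument selects among the three strategies: "unpro" grows the full unpruned tree, "prev" pre-prunes while growing, and "post" prunes after growing. A sketch of building all three trees from the same split (train_data, test_data, labels, data_full and labels_full as returned by load_data below):

import copy

# createTree and chooseBestFeatureToSplit mutate dataSet and labels in
# place, so each run gets its own deep copies
for mode in ("unpro", "prev", "post"):
  tree = createTree(copy.deepcopy(train_data), copy.deepcopy(labels),
                    data_full, labels_full, copy.deepcopy(test_data), mode=mode)
  print(mode, tree)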

**Loading the data**

import pandas as pd

def load_data(file_name):
  with open(file_name, 'r') as f:
    df = pd.read_csv(f, sep=",")
  # the first column is a row index, the last column is the class label;
  # the first 11 rows are used for training, the rest for testing
  train_data = df.values[:11, 1:].tolist()
  test_data = df.values[11:, 1:].tolist()
  labels = df.columns.values[1:-1].tolist()
  return train_data, test_data, labels
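load_data expects a CSV shaped like the author's dd.csv: an index column first, then the feature columns, with the class label last. A hypothetical file of that shape (column names and values are made up):

id,texture,density,label
0,clear,0.46,good
1,clear,0.77,good
2,blurry,0.24,bad
...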

**Testing and plotting the tree**
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="round4", color='red')  # style of decision nodes
leafNode = dict(boxstyle="circle", color='grey')     # style of leaf nodes
arrow_args = dict(arrowstyle="<-", color='blue')     # style of the arrows


# count the number of leaf nodes in the tree
def getNumLeafs(myTree):
  numLeafs = 0
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  secondDict = myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      numLeafs += getNumLeafs(secondDict[key])
    else:
      numLeafs += 1
  return numLeafs


# compute the maximum depth of the tree
def getTreeDepth(myTree):
  maxDepth = 0
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  secondDict = myTree[firstStr]
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      thisDepth = 1 + getTreeDepth(secondDict[key])
    else:
      thisDepth = 1
    if thisDepth > maxDepth:
      maxDepth = thisDepth
  return maxDepth


# draw a node
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
              xytext=centerPt, textcoords='axes fraction', va="center", ha="center", \
              bbox=nodeType, arrowprops=arrow_args)


# draw the label text on an arrow
def plotMidText(cntrPt, parentPt, txtString):
  lens = len(txtString)
  xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
  yMid = (parentPt[1] + cntrPt[1]) / 2.0
  createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
  numLeafs = getNumLeafs(myTree)
  depth = getTreeDepth(myTree)
  firstSides = list(myTree.keys())
  firstStr = firstSides[0]
  cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
  plotMidText(cntrPt, parentPt, nodeTxt)
  plotNode(firstStr, cntrPt, parentPt, decisionNode)
  secondDict = myTree[firstStr]
  plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
  for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':
      plotTree(secondDict[key], cntrPt, str(key))
    else:
      plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
      plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
      plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
  plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


def createPlot(inTree):
  fig = plt.figure(1, facecolor='white')
  fig.clf()
  axprops = dict(xticks=[], yticks=[])
  createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
  plotTree.totalW = float(getNumLeafs(inTree))
  plotTree.totalD = float(getTreeDepth(inTree))
  plotTree.x0ff = -0.5 / plotTree.totalW
  plotTree.y0ff = 1.0
  plotTree(inTree, (0.5, 1.0), '')
  plt.show()
import json

if __name__ == "__main__":
  train_data, test_data, labels = load_data("dd.csv")
  data_full = train_data[:]
  labels_full = labels[:]

  # choose "unpro", "prev" or "post" to build the corresponding tree
  mode = "post"
  myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)
  createPlot(myTree)
  print(json.dumps(myTree, ensure_ascii=False, indent=4))

Selecting each mode in turn yields the three different tree plots.
[Figures: the three resulting tree plots, one for each pruning mode]

Hopefully this article is of help to readers working on Python programming.
