python实现基于信息增益的决策树归纳


Posted in Python onDecember 18, 2018

本文实例为大家分享了基于信息增益的决策树归纳的Python实现代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
from copy import copy
 
#加载训练数据
#文件格式:属性标号,是否连续【yes|no】,属性说明
attribute_file_dest = 'F:\\bayes_categorize\\attribute.dat'
attribute_file = open(attribute_file_dest)
 
#文件格式:rec_id,attr1_value,attr2_value,...,attrn_value,class_id
trainning_data_file_dest = 'F:\\bayes_categorize\\trainning_data.dat'
trainning_data_file = open(trainning_data_file_dest)
 
#文件格式:class_id,class_desc
class_desc_file_dest = 'F:\\bayes_categorize\\class_desc.dat'
class_desc_file = open(class_desc_file_dest)
 
 
root_attr_dict = {}
for line in attribute_file :
  line = line.strip()
  fld_list = line.split(',')
  root_attr_dict[int(fld_list[0])] = tuple(fld_list[1:])
 
class_dict = {}
for line in class_desc_file :
  line = line.strip()
  fld_list = line.split(',')
  class_dict[int(fld_list[0])] = fld_list[1]
  
trainning_data_dict = {}
class_member_set_dict = {}
for line in trainning_data_file :
  line = line.strip()
  fld_list = line.split(',')
  rec_id = int(fld_list[0])
  a1 = int(fld_list[1])
  a2 = int(fld_list[2])
  a3 = float(fld_list[3])
  c_id = int(fld_list[4])
  
  if c_id not in class_member_set_dict :
    class_member_set_dict[c_id] = set()
  class_member_set_dict[c_id].add(rec_id)
  trainning_data_dict[rec_id] = (a1 , a2 , a3 , c_id)
  
attribute_file.close()
class_desc_file.close()
trainning_data_file.close()
 
class_possibility_dict = {}
for c_id in class_member_set_dict :
  class_possibility_dict[c_id] = (len(class_member_set_dict[c_id]) + 0.0)/len(trainning_data_dict)  
 
#等待分类的数据
data_to_classify_file_dest = 'F:\\bayes_categorize\\trainning_data_new.dat'
data_to_classify_file = open(data_to_classify_file_dest)
data_to_classify_dict = {}
for line in data_to_classify_file :
  line = line.strip()
  fld_list = line.split(',')
  rec_id = int(fld_list[0])
  a1 = int(fld_list[1])
  a2 = int(fld_list[2])
  a3 = float(fld_list[3])
  c_id = int(fld_list[4])
  data_to_classify_dict[rec_id] = (a1 , a2 , a3 , c_id)
data_to_classify_file.close()
 
 
 
 
'''
决策树的表达
结点的需求:
1、指示出是哪一种分区 一共3种 一是离散穷举 二是连续有分裂点 三是离散有判别集合 零是叶子结点
2、保存分类所需信息
3、子结点列表
每个结点用Tuple类型表示
元素一是整形,取值123 分别对应两种分裂类型
元素二是集合类型 对于1保存所有的离散值 对于2保存分裂点 对于3保存判别集合 对于0保存分类结果类标号
元素三是dict key对于1来说是某个的离散值 对于23来说只有12两种 对于2来说1代表小于等于分裂点
对于3来说1代表属于判别集合
'''
 
  
#对于一个成员列表,计算其熵
#公式为 Info_D = - sum(pi * log2 (pi)) pi为一个元素属于Ci的概率,用|Ci|/|D|计算 ,对所有分类求和
def get_entropy( member_list ) :
  #成员总数
  mem_cnt = len(member_list)
  #首先找出member中所包含的分类
  class_dict = {}
  for mem_id in member_list :
    c_id = trainning_data_dict[mem_id][3]
    if c_id not in class_dict :
      class_dict[c_id] = set()
    class_dict[c_id].add(mem_id)
  
  tmp_sum = 0.0
  for c_id in class_dict :
    pi = ( len(class_dict[c_id]) + 0.0 ) / mem_cnt
    tmp_sum += pi * mlab.log2(pi)
  tmp_sum = -tmp_sum
  return tmp_sum
    
 
def attribute_selection_method( member_list , attribute_dict ) :
  #先计算原始的熵
  info_D = get_entropy(member_list)
  
  max_info_Gain = 0.0
  attr_get = 0
  split_point = 0.0
  for attr_id in attribute_dict :
    #对于每一个属性计算划分后的熵
    #信息增益等于原始的熵减去划分后的熵
    info_D_new = 0
    #如果是连续属性
    if attribute_dict[attr_id][0] == 'yes' :
      #先得到memberlist中此属性的取值序列,把序列中每一对相邻项的中值作为划分点计算熵
      #找出其中最小的,作为此连续属性的划分点
      value_list = []
      for mem_id in member_list :
        value_list.append(trainning_data_dict[mem_id][attr_id - 1])
      
      #获取相邻元素的中值序列
      mid_value_list = []
      value_list.sort()
      #print value_list
      last_value = None
      for value in value_list :
        if value == last_value :
          continue
        if last_value is not None :
          mid_value_list.append((last_value+value)/2)
        last_value = value
      #print mid_value_list
      #对于中值序列做循环
      #计算以此值做为划分点的熵
      #总的熵等于两个划分的熵乘以两个划分的比重
      min_info = 1000000000.0
      total_mens = len(member_list) + 0.0
      for mid_value in mid_value_list :
        #小于mid_value的mem
        less_list = []
        #大于
        more_list = []
        for tmp_mem_id in member_list :
          if trainning_data_dict[tmp_mem_id][attr_id - 1] <= mid_value :
            less_list.append(tmp_mem_id)
          else :
            more_list.append(tmp_mem_id)
        sum_info = len(less_list)/total_mens * get_entropy(less_list) \
        + len(more_list)/total_mens * get_entropy(more_list)
        
        if sum_info < min_info :
          min_info = sum_info
          split_point = mid_value
          
      info_D_new = min_info
    #如果是离散属性
    else :
      #计算划分后的熵
      #采用循环累加的方式
      attr_value_member_dict = {} #键为attribute value , 值为memberlist
      for tmp_mem_id in member_list :
        attr_value = trainning_data_dict[tmp_mem_id][attr_id - 1]
        if attr_value not in attr_value_member_dict :
          attr_value_member_dict[attr_value] = []
        attr_value_member_dict[attr_value].append(tmp_mem_id)
      #将每个离散值的熵乘以比重加到这上面
      total_mens = len(member_list) + 0.0
      sum_info = 0.0
      for a_value in attr_value_member_dict :
        sum_info += len(attr_value_member_dict[a_value])/total_mens \
        * get_entropy(attr_value_member_dict[a_value])
      
      info_D_new = sum_info
    
    info_Gain = info_D - info_D_new
    if info_Gain > max_info_Gain :
      max_info_Gain = info_Gain
      attr_get = attr_id
  
  #如果是离散的
  #print 'attr_get ' + str(attr_get)
  if attribute_dict[attr_get][0] == 'no' :
    return (1 , attr_get , split_point)
  else :  
    return (2 , attr_get , split_point)
  #第三类先不考虑
 
def get_decision_tree(father_node , key , member_list , attr_dict ) :
  #最终的结果是新建一个结点,并且添加到father_node的sub_node_dict,对key为键
  #检查memberlist 如果都是同类的,则生成一个叶子结点,set里面保存类标号
  class_set = set()
  for mem_id in member_list :
    class_set.add(trainning_data_dict[mem_id][3])
  if len(class_set) == 1 :
    father_node[2][key] = (0 , (1 , class_set) , {} )
    return
  
  #检查attribute_list,如果为空,产生叶子结点,类标号为memberlist中多数元素的类标号
  #如果几个类的成员等量,则打印提示,并且全部添加到set里面
  if not attr_dict :
    class_cnt_dict = {}
    for mem_id in member_list :
      c_id = trainning_data_dict[mem_id][3]
      if c_id not in class_cnt_dict :
        class_cnt_dict[c_id] = 1
      else :
        class_cnt_dict[c_id] += 1
        
    class_set = set()
    max_cnt = 0
    for c_id in class_cnt_dict :
      if class_cnt_dict[c_id] > max_cnt :
        max_cnt = class_cnt_dict[c_id]
        class_set.clear()
        class_set.add(c_id)
      elif class_cnt_dict[c_id] == max_cnt :
        class_set.add(c_id)
    
    if len(class_set) > 1 :
      print 'more than one class !'
    
    father_node[2][key] = (0 , (1 , class_set ) , {} )
    return
  
  #找出最好的分区方案 , 暂不考虑第三种划分方法
  #比较所有离散属性和所有连续属性的所有中值点划分的信息增益
  split_criterion = attribute_selection_method(member_list , attr_dict)
  #print split_criterion
  selected_plan_id = split_criterion[0]
  selected_attr_id = split_criterion[1]
  
  #如果采用的是离散属性做为分区方案,删除这个属性
  new_attr_dict = copy(attr_dict)
  if attr_dict[selected_attr_id][0] == 'no' :
    del new_attr_dict[selected_attr_id]
  
  #建立一个结点new_node,father_node[2][key] = new_node
  #然后对new node的每一个key , sub_member_list,
  #调用 get_decision_tree(new_node , new_key , sub_member_list , new_attribute_dict)
  #实现递归
  ele2 = ( selected_attr_id , set() )
  #如果是1 , ele2保存所有离散值
  if selected_plan_id == 1 :
    for mem_id in member_list :
      ele2[1].add(trainning_data_dict[mem_id][selected_attr_id - 1])
  #如果是2,ele2保存分裂点
  elif selected_plan_id == 2 :
    ele2[1].add(split_criterion[2])
  #如果是3则保存判别集合,先不管
  else :
    print 'not completed'
    pass
    
  new_node = ( selected_plan_id , ele2 , {} )
  father_node[2][key] = new_node
  
  #生成KEY,并递归调用
  if selected_plan_id == 1 :
    #每个attr_value是一个key
    attr_value_member_dict = {}
    for mem_id in member_list :
      attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
      if attr_value not in attr_value_member_dict :
        attr_value_member_dict[attr_value] = []
      attr_value_member_dict[attr_value].append(mem_id)
    for attr_value in attr_value_member_dict :
      get_decision_tree(new_node , attr_value , attr_value_member_dict[attr_value] , new_attr_dict)
    pass
  elif selected_plan_id == 2 :
    #key 只有12 , 小于等于分裂点的是1 , 大于的是2
    less_list = []
    more_list = []
    for mem_id in member_list :
      attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
      if attr_value <= split_criterion[2] :
        less_list.append(mem_id)
      else :
        more_list.append(mem_id)
    #if len(less_list) != 0 :
    get_decision_tree(new_node , 1 , less_list , new_attr_dict)
    #if len(more_list) != 0 :
    get_decision_tree(new_node , 2 , more_list , new_attr_dict)
    pass
  #如果是3则保存判别集合,先不管
  else :
    print 'not completed'
    pass
  
def get_class_sub(node , tp ) :
  #
  attr_id = node[1][0]
  plan_id = node[0]
  key = 0
  if plan_id == 0 :
    return node[1][1]
  elif plan_id == 1 :
    key = tp[attr_id - 1]
  elif plan_id == 2 :
    split_point = tuple(node[1][1])[0]
    attr_value = tp[attr_id - 1]
    if attr_value <= split_point :
      key = 1
    else :
      key = 2
  else :
    print 'error'
    return set()
    
  return get_class_sub(node[2][key] , tp )
 
def get_class(r_node , tp) :
  #tp为一组属性值
  if r_node[0] != -1 :
    print 'error'
    return set()
  
  if 1 in r_node[2] :
    return get_class_sub(r_node[2][1] , tp)
  else :
    print 'error'
    return set()
  
  
if __name__ == '__main__' :
  root_node = ( -1 , set() , {} )
  mem_list = trainning_data_dict.keys()
  get_decision_tree(root_node , 1 , mem_list , root_attr_dict )
 
  #测试分类器的准确率
  diff_cnt = 0
  for mem_id in data_to_classify_dict :
    c_id = get_class(root_node , data_to_classify_dict[mem_id][0:3])
    if tuple(c_id)[0] != data_to_classify_dict[mem_id][3] :
      print tuple(c_id)[0]
      print data_to_classify_dict[mem_id][3]
      print 'different'
      diff_cnt += 1
  print diff_cnt

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python备份文件的脚本
Aug 11 Python
django接入新浪微博OAuth的方法
Jun 29 Python
Django 导出 Excel 代码的实例详解
Aug 11 Python
python 读取视频,处理后,实时计算帧数fps的方法
Jul 10 Python
基于tensorflow加载部分层的方法
Jul 26 Python
Django如何开发简单的查询接口详解
May 17 Python
Django CBV类的用法详解
Jul 26 Python
python3光学字符识别模块tesserocr与pytesseract的使用详解
Feb 26 Python
python实现二分类和多分类的ROC曲线教程
Jun 15 Python
使用OpenCV实现道路车辆计数的使用方法
Jul 15 Python
Python自动登录QQ的实现示例
Aug 28 Python
python实战之用emoji表情生成文字
May 08 Python
Django实现一对多表模型的跨表查询方法
Dec 18 #Python
Python实现字典排序、按照list中字典的某个key排序的方法示例
Dec 18 #Python
python实现求特征选择的信息增益
Dec 18 #Python
python实现连续图文识别
Dec 18 #Python
Django ManyToManyField 跨越中间表查询的方法
Dec 18 #Python
Python列表list排列组合操作示例
Dec 18 #Python
python实现二维插值的三维显示
Dec 17 #Python
You might like
农民C键的运用技巧
2020/03/04 星际争霸
PHP页面间传递参数实例代码
2008/06/05 PHP
PHP文件读写操作之文件读取方法详解
2011/01/13 PHP
php自动识别文件编码并转换为UTF-8的方法
2014/06/12 PHP
php常用hash加密函数
2014/11/22 PHP
jQuery向下滚动即时加载内容实现的瀑布流效果
2016/01/07 PHP
php中数组最简单的使用方法
2020/12/27 PHP
jQuery实现多按钮单击变色
2014/11/27 Javascript
Jquery中使用show()与hide()方法动画显示和隐藏图片
2015/10/08 Javascript
jquery编写Tab选项卡滚动导航切换特效
2020/07/17 Javascript
客户端验证用户名和密码的方法详解
2016/06/16 Javascript
一个简单不报错的summernote 图片上传案例
2016/07/11 Javascript
AngularJS ng-bind 指令简单实现
2016/07/30 Javascript
jQuery选择器实例应用
2017/01/05 Javascript
一文让你彻底搞清楚javascript中的require、import与export
2017/09/24 Javascript
vue+Vue Router多级侧导航切换路由(页面)的实现代码
2018/12/20 Javascript
实例讲解v-if和v-show的区别
2019/01/31 Javascript
vue语法自动转typescript(解放双手)
2019/09/18 Javascript
JavaScript实现点击自制菜单效果
2021/02/02 Javascript
[54:10]Spirit vs NB Supermajor小组赛 A组败者组决赛 BO3 第一场 6.2
2018/06/03 DOTA
python实现简单的TCP代理服务器
2014/10/08 Python
Python中强大的命令行库click入门教程
2016/12/26 Python
python如何通过实例方法名字调用方法
2018/03/21 Python
python中时间模块的基本使用教程
2019/05/14 Python
python正则-re的用法详解
2019/07/28 Python
浅谈python3中input输入的使用
2019/08/02 Python
Python实现网页截图(PyQT5)过程解析
2019/08/12 Python
儿科护士实习自我鉴定
2013/10/17 职场文书
电子商务个人自荐信
2013/12/12 职场文书
代理协议书
2014/04/22 职场文书
新生开学寄语大全
2015/05/28 职场文书
世界名著读书笔记
2015/06/25 职场文书
学习社交礼仪心得体会
2016/01/22 职场文书
导游词之天津盘山
2019/11/01 职场文书
golang 实用库gotable的具体使用
2021/07/01 Golang
Vue OpenLayer 为地图绘制风场效果
2022/04/24 Vue.js