当前位置:Gxlcms > Python > python如何实现决策树算法?(代码)

python如何实现决策树算法?(代码)

时间:2021-07-01 10:21:17 帮助过:655人阅读

本篇文章给大家带来的内容是关于python如何实现决策树算法?(代码),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。

数据描述

每条数据项储存在列表中,最后一列储存结果
多条数据项形成数据集

  1. data=[[d1,d2,d3...dn,result],
  2. [d1,d2,d3...dn,result],
  3. .
  4. .
  5. [d1,d2,d3...dn,result]]

决策树数据结构

  1. class DecisionNode:
  2. '''决策树节点
  3. '''
  4. def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
  5. '''初始化决策树节点
  6. args:
  7. col -- 按数据集的col列划分数据集
  8. value -- 以value作为划分col列的参照
  9. result -- 只有叶子节点有,代表最终划分出的子数据集结果统计信息。{‘结果’:结果出现次数}
  10. rb,fb -- 代表左右子树
  11. '''
  12. self.col=col
  13. self.value=value
  14. self.results=results
  15. self.tb=tb
  16. self.fb=fb

决策树分类的最终结果是将数据项划分出了若干子集,其中每个子集的结果都一样,所以这里采用{‘结果’:结果出现次数}的方式表达每个子集

  1. def pideset(rows,column,value):
  2. '''依据数据集rows的column列的值,判断其与参考值value的关系对数据集进行拆分
  3. 返回两个数据集
  4. '''
  5. split_function=None
  6. #value是数值类型
  7. if isinstance(value,int) or isinstance(value,float):
  8. #定义lambda函数当row[column]>=value时返回true
  9. split_function=lambda row:row[column]>=value
  10. #value是字符类型
  11. else:
  12. #定义lambda函数当row[column]==value时返回true
  13. split_function=lambda row:row[column]==value
  14. #将数据集拆分成两个
  15. set1=[row for row in rows if split_function(row)]
  16. set2=[row for row in rows if not split_function(row)]
  17. #返回两个数据集
  18. return (set1,set2)
  19. def uniquecounts(rows):
  20. '''计算数据集rows中有几种最终结果,计算结果出现次数,返回一个字典
  21. '''
  22. results={}
  23. for row in rows:
  24. r=row[len(row)-1]
  25. if r not in results: results[r]=0
  26. results[r]+=1
  27. return results
  28. def giniimpurity(rows):
  29. '''返回rows数据集的基尼不纯度
  30. '''
  31. total=len(rows)
  32. counts=uniquecounts(rows)
  33. imp=0
  34. for k1 in counts:
  35. p1=float(counts[k1])/total
  36. for k2 in counts:
  37. if k1==k2: continue
  38. p2=float(counts[k2])/total
  39. imp+=p1*p2
  40. return imp
  41. def entropy(rows):
  42. '''返回rows数据集的熵
  43. '''
  44. from math import log
  45. log2=lambda x:log(x)/log(2)
  46. results=uniquecounts(rows)
  47. ent=0.0
  48. for r in results.keys():
  49. p=float(results[r])/len(rows)
  50. ent=ent-p*log2(p)
  51. return ent
  52. def build_tree(rows,scoref=entropy):
  53. '''构造决策树
  54. '''
  55. if len(rows)==0: return DecisionNode()
  56. current_score=scoref(rows)
  57. # 最佳信息增益
  58. best_gain=0.0
  59. #
  60. best_criteria=None
  61. #最佳划分
  62. best_sets=None
  63. column_count=len(rows[0])-1
  64. #遍历数据集的列,确定分割顺序
  65. for col in range(0,column_count):
  66. column_values={}
  67. # 构造字典
  68. for row in rows:
  69. column_values[row[col]]=1
  70. for value in column_values.keys():
  71. (set1,set2)=pideset(rows,col,value)
  72. p=float(len(set1))/len(rows)
  73. # 计算信息增益
  74. gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
  75. if gain>best_gain and len(set1)>0 and len(set2)>0:
  76. best_gain=gain
  77. best_criteria=(col,value)
  78. best_sets=(set1,set2)
  79. # 如果划分的两个数据集熵小于原数据集,进一步划分它们
  80. if best_gain>0:
  81. trueBranch=build_tree(best_sets[0])
  82. falseBranch=build_tree(best_sets[1])
  83. return DecisionNode(col=best_criteria[0],value=best_criteria[1],
  84. tb=trueBranch,fb=falseBranch)
  85. # 如果划分的两个数据集熵不小于原数据集,停止划分
  86. else:
  87. return DecisionNode(results=uniquecounts(rows))
  88. def print_tree(tree,indent=''):
  89. if tree.results!=None:
  90. print(str(tree.results))
  91. else:
  92. print(str(tree.col)+':'+str(tree.value)+'? ')
  93. print(indent+'T->',end='')
  94. print_tree(tree.tb,indent+' ')
  95. print(indent+'F->',end='')
  96. print_tree(tree.fb,indent+' ')
  97. def getwidth(tree):
  98. if tree.tb==None and tree.fb==None: return 1
  99. return getwidth(tree.tb)+getwidth(tree.fb)
  100. def getdepth(tree):
  101. if tree.tb==None and tree.fb==None: return 0
  102. return max(getdepth(tree.tb),getdepth(tree.fb))+1
  103. def drawtree(tree,jpeg='tree.jpg'):
  104. w=getwidth(tree)*100
  105. h=getdepth(tree)*100+120
  106. img=Image.new('RGB',(w,h),(255,255,255))
  107. draw=ImageDraw.Draw(img)
  108. drawnode(draw,tree,w/2,20)
  109. img.save(jpeg,'JPEG')
  110. def drawnode(draw,tree,x,y):
  111. if tree.results==None:
  112. # Get the width of each branch
  113. w1=getwidth(tree.fb)*100
  114. w2=getwidth(tree.tb)*100
  115. # Determine the total space required by this node
  116. left=x-(w1+w2)/2
  117. right=x+(w1+w2)/2
  118. # Draw the condition string
  119. draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))
  120. # Draw links to the branches
  121. draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
  122. draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))
  123. # Draw the branch nodes
  124. drawnode(draw,tree.fb,left+w1/2,y+100)
  125. drawnode(draw,tree.tb,right-w2/2,y+100)
  126. else:
  127. txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
  128. draw.text((x-20,y),txt,(0,0,0))

对测试数据进行分类(附带处理缺失数据)

  1. def mdclassify(observation,tree):
  2. '''对缺失数据进行分类
  3. args:
  4. observation -- 发生信息缺失的数据项
  5. tree -- 训练完成的决策树
  6. 返回代表该分类的结果字典
  7. '''
  8. # 判断数据是否到达叶节点
  9. if tree.results!=None:
  10. # 已经到达叶节点,返回结果result
  11. return tree.results
  12. else:
  13. # 对数据项的col列进行分析
  14. v=observation[tree.col]
  15. # 若col列数据缺失
  16. if v==None:
  17. #对tree的左右子树分别使用mdclassify,tr是左子树得到的结果字典,fr是右子树得到的结果字典
  18. tr,fr=mdclassify(observation,tree.tb),mdclassify(observation,tree.fb)
  19. # 分别以结果占总数比例计算得到左右子树的权重
  20. tcount=sum(tr.values())
  21. fcount=sum(fr.values())
  22. tw=float(tcount)/(tcount+fcount)
  23. fw=float(fcount)/(tcount+fcount)
  24. result={}
  25. # 计算左右子树的加权平均
  26. for k,v in tr.items():
  27. result[k]=v*tw
  28. for k,v in fr.items():
  29. # fr的结果k有可能并不在tr中,在result中初始化k
  30. if k not in result:
  31. result[k]=0
  32. # fr的结果累加到result中
  33. result[k]+=v*fw
  34. return result
  35. # col列没有缺失,继续沿决策树分类
  36. else:
  37. if isinstance(v,int) or isinstance(v,float):
  38. if v>=tree.value: branch=tree.tb
  39. else: branch=tree.fb
  40. else:
  41. if v==tree.value: branch=tree.tb
  42. else: branch=tree.fb
  43. return mdclassify(observation,branch)
  44. tree=build_tree(my_data)
  45. print(mdclassify(['google',None,'yes',None],tree))
  46. print(mdclassify(['google','France',None,None],tree))

决策树剪枝

  1. def prune(tree,mingain):
  2. '''对决策树进行剪枝
  3. args:
  4. tree -- 决策树
  5. mingain -- 最小信息增益
  6. 返回
  7. '''
  8. # 修剪非叶节点
  9. if tree.tb.results==None:
  10. prune(tree.tb,mingain)
  11. if tree.fb.results==None:
  12. prune(tree.fb,mingain)
  13. #合并两个叶子节点
  14. if tree.tb.results!=None and tree.fb.results!=None:
  15. tb,fb=[],[]
  16. for v,c in tree.tb.results.items():
  17. tb+=[[v]]*c
  18. for v,c in tree.fb.results.items():
  19. fb+=[[v]]*c
  20. #计算熵减少情况
  21. delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)
  22. #熵的增加量小于mingain,可以合并分支
  23. if delta<mingain:
  24. tree.tb,tree.fb=None,None
  25. tree.results=uniquecounts(tb+fb)

以上就是python如何实现决策树算法?(代码)的详细内容,更多请关注Gxl网其它相关文章!

人气教程排行