使用文本注解绘制树节点

`decisionNode = dict(boxstyle=’sawtooth’,fc=’0.8’) #对非叶子节点框进行设置
_leafNode = dict(boxstyle=”round4”,fc=”0.8”)
#对叶子节点框进行设置
arrow_args = dict(arrowstyle=”<-“) #对箭头进行设置

def plotNode(nodeTxt,centerPt,parentPt,nodeType): #节点文字,节点框位置,父节点位置,节点类型
_createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords=”axes fraction”,
xytext=centerPt,textcoords=”axes fraction”,
va=”center”,ha=”center”,bbox=nodeType,arrowprops=arrow_args)

def createPlot():
fig = plt.figure(1,facecolor=”white”)
fig.clf()
createPlot.ax1 = plt.subplot(111, frameon=False)
plotNode(“决策节点”,(0.5,0.1),(0.1,0.5),decisionNode)
plotNode(“叶节点”,(0.8,0.1),(0.3,0.8),leafNode)
plt.show()`

调用及结果:

image.png


构造注解树

获取决策树节点的数目和树的层数

`# 获取叶节点的数目,参数为构造好的字典类型的决策树
_def getNumLeafs(decisionTree):
numLeafs = 0
#初始化节点数目
firstStr = list(decisionTree.keys())[0] #获取根节点
secondDict = decisionTree[firstStr] #获取根节点的值,即子一级的树结构
# 对子树继续进行操作
for key in secondDict.keys():
# 若子树根节点的值依然为字典类型,即代表该节点仍为非叶子节点,则继续对下一级子树进行操作
if type(secondDict[key]).name == “dict”:
numLeafs += getNumLeafs(secondDict[key])
# 若当前key的值不为字典类型,即代表该key所代表的节点为叶子节点,就对节点数目+1
else:
numLeafs += 1
return numLeafs
# 返回节点数目

获取树的层数,参数为构造好的字典类型的决策树
def getTreeDepth(decisionTree):
maxDepth = 0
#初始化树的层数
firstStr = list(decisionTree.keys())[0] #获取最外层字典的键,即根节点
secondDict = decisionTree[firstStr] #获取根节点的值,即子一级的树结构
# 对子树继续进行操作
for key in secondDict.keys():
# 若子树根节点的值依然为字典类型,即代表该节点仍为非叶子节点,则继续对下一级子树进行操作
if type(secondDict[key]).name == “dict”:
thisDepth = 1 + getTreeDepth(secondDict[key])
# 若当前key的值不为字典类型,即代表该key所代表的节点为叶子节点,不再有子树,故对当前层数参数thisDepth赋为1
else:
thisDepth = 1
if thisDepth > maxDepth:
# 对最大层数maxDepth进行更新
maxDepth = thisDepth
return maxDepth
# 返回决策树的最大层数_`


绘制树

`#绘制树
# 在父子节点间填充文本信息
_def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[0]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)
#计算宽和高
def plotTree(decisionTree,parentPt,nodeTxt):
numLeafs = getNumLeafs(decisionTree)
#获取树的节点数目
depth = getTreeDepth(decisionTree) #获取树的层数
firstStr = list(decisionTree.keys())[0] #获取根节点
cntrPt = (plotTree.xoff + (1.0 + float(numLeafs))/2.0/plotTree.totalw,plotTree.yoff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict = decisionTree[firstStr]
# 减少y偏移
plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalw
for key in secondDict.keys():
if type(secondDict[key])._name
== “dict”:
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalw
plotNode(secondDict[key],(plotTree.xoff,plotTree.yoff),cntrPt,leafNode)
plotMidText((plotTree.xoff,plotTree.yoff),cntrPt,str(key))
plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD

def createPlot(inTree):
fig = plt.figure(1,facecolor=”white”)
fig.clf()
axprops = dict(xticks=[],yticks=[])
createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
plotTree.totalw = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xoff = -0.5 / plotTree.totalw;plotTree.yoff = 1.0;
plotTree(inTree,(0.5,1.0),””)
plt.show()`

调用及结果:

dataSet, labels = createDataSet()<br />tree = dt.createTree(dataSet, labels)<br />print("tree = ",tree)<br />createPlot(tree)
image.png
image.png