import matplotlib.pyplot as plt

# Box styles handed to Axes.annotate(bbox=...) by plotNode: decision (internal)
# nodes get a sawtooth outline, leaves a rounded one; "fc" is the face colour
# given as a grayscale string.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style handed to Axes.annotate(arrowprops=...) for the parent connector.
arrow_args = {"arrowstyle": "<-"}
    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        """Draw one node box on createPlot.ax1 and the arrow joining it to its parent.

        Args:
            nodeTxt: text shown inside the box.
            centerPt: (x, y) where the box text is centred, in axes-fraction
                coordinates.
            parentPt: (x, y) the arrow connects to, in axes-fraction coordinates.
            nodeType: ``bbox`` style dict (decisionNode or leafNode).
        """
        frac = 'axes fraction'
        createPlot.ax1.annotate(
            nodeTxt,
            xy=parentPt, xycoords=frac,
            xytext=centerPt, textcoords=frac,
            bbox=nodeType, arrowprops=arrow_args,
            va="center", ha="center",
        )
    def getNumLeafs(myTree):
        """Count the leaves of a nested-dict tree.

        ``myTree`` maps a node label (only its first key is used) to a mapping
        of branch keys to children.  A child that is a ``dict`` (including
        subclasses such as ``OrderedDict``) is a subtree and is counted
        recursively; any other value counts as one leaf.

        Args:
            myTree: non-empty mapping ``{label: {branch: child, ...}}``.

        Returns:
            int: total number of leaves; 0 if the root has no branches.

        Raises:
            IndexError: if ``myTree`` is empty.
        """
        firstStr = list(myTree)[0]
        secondDict = myTree[firstStr]
        # isinstance (not a type-name string comparison) so dict subclasses are
        # recognised as subtrees instead of being miscounted as a single leaf.
        return sum(
            getNumLeafs(child) if isinstance(child, dict) else 1
            for child in secondDict.values()
        )
    def getTreeDepth(myTree):
        """Return the depth of a nested-dict tree, counted in decision levels.

        A branch ending in a leaf contributes 1; a branch leading to a subtree
        (any ``dict``, including subclasses) contributes 1 plus that subtree's
        depth.  The result is the maximum over the root's branches.

        Args:
            myTree: non-empty mapping ``{label: {branch: child, ...}}``; only
                the first key is used.

        Returns:
            int: the depth; 0 if the root has no branches.

        Raises:
            IndexError: if ``myTree`` is empty.
        """
        firstStr = list(myTree)[0]
        secondDict = myTree[firstStr]
        # isinstance so dict subclasses (e.g. OrderedDict) count as subtrees.
        return max(
            (getTreeDepth(child) + 1 if isinstance(child, dict) else 1
             for child in secondDict.values()),
            default=0,
        )
    def plotMidText(cntrPt, parentPt, txtString):
        """Place ``txtString`` halfway between ``cntrPt`` and ``parentPt``.

        plotTree uses this to label each edge with its branch key.  Points are
        in axes-fraction coordinates, the same system plotNode uses for boxes.
        """
        axes = createPlot.ax1
        xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
        yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
        # Axes.text defaults to data coordinates, which only coincide with
        # plotNode's 'axes fraction' placement while the axes keep their
        # default 0..1 limits; pin the coordinate system explicitly.
        axes.text(xMid, yMid, txtString, transform=axes.transAxes)
    def plotTree(myTree, parentPt, nodeTxt):
        """Recursively draw ``myTree`` below ``parentPt``.

        Draws the root label as a decision node and labels the incoming edge
        with ``nodeTxt``. Then it draws each branch. Subtrees (any ``dict``)
        recurse. Other values are drawn as leaf nodes, with the branch key on
        the edge.

        Layout state lives on function attributes that createPlot initialises
        before the first call:
        ``plotTree.totalw`` is the total leaf count.
        ``plotTree.totalD`` is the total depth.
        ``plotTree.xOff`` and ``plotTree.yOff`` are the drawing cursor, in
        axes-fraction coordinates.
        ``xOff`` only advances here. ``yOff`` is stepped back up before
        returning.

        Args:
            myTree: non-empty mapping ``{label: {branch: child, ...}}``.
            parentPt: (x, y) of the parent node, in axes-fraction coordinates.
            nodeTxt: text for the edge from the parent to this node.
        """
        numLeafs = getNumLeafs(myTree)
        firstStr = list(myTree)[0]
        # Centre this node horizontally over the span of leaves beneath it.
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw,
                  plotTree.yOff)
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        # Children are drawn one level lower...
        plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
        for key, child in secondDict.items():
            if isinstance(child, dict):
                plotTree(child, cntrPt, str(key))
            else:
                # Each leaf takes the next 1/totalw slot to the right.
                plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalw
                leafPt = (plotTree.xOff, plotTree.yOff)
                plotNode(child, leafPt, cntrPt, leafNode)
                plotMidText(leafPt, cntrPt, str(key))
        # ...and the level is restored so this node's later siblings line up.
        plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
    def createPlot(inTree):
        """Draw ``inTree`` in matplotlib figure 1 and display it.

        Also (re)initialises the shared drawing state. That state is:
        ``createPlot.ax1``, the target axes used by plotNode and plotMidText.
        The ``plotTree`` layout attributes: ``totalw`` (leaf count), ``totalD``
        (depth), and ``xOff``/``yOff`` (the drawing cursor, in axes-fraction
        coordinates).

        Args:
            inTree: non-empty mapping ``{label: {branch: child, ...}}``.
        """
        figure = plt.figure(1, facecolor='white')
        figure.clf()  # reuse figure 1 but start from an empty canvas
        # Frameless axes with no ticks, so only the tree itself is visible.
        createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
        plotTree.totalw = float(getNumLeafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))
        # Start half a slot left of the axes. plotTree advances one slot
        # (1/totalw) before each leaf, so the first leaf lands at x = 0.5/totalw.
        plotTree.xOff = -0.5 / plotTree.totalw
        plotTree.yOff = 1.0
        # (0.5, 1.0) serves as the root's parent point; the root edge is unlabelled.
        plotTree(inTree, (0.5, 1.0), '')
        plt.show()