Resources and downloads on GitHub: https://github.com/parrt/dtreeviz
Installation:
pip install dtreeviz # install dtreeviz for sklearn
pip install dtreeviz[xgboost] # install XGBoost-related dependencies
pip install dtreeviz[pyspark] # install PySpark-related dependencies
pip install dtreeviz[lightgbm] # install LightGBM-related dependencies
pip install dtreeviz[tensorflow_decision_forests] # install TensorFlow Decision Forests-related dependencies
pip install dtreeviz[all] # install all optional dependencies

Example 1: Iris dataset

Code:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import dtreeviz
# Load the iris data: 150 samples, 4 numeric features, 3 classes
iris = load_iris()
X = iris.data
y = iris.target
# Fit a shallow decision tree so the visualization stays readable
clf = DecisionTreeClassifier(max_depth=4)
clf.fit(X, y)
# Wrap the fitted tree together with its training data and labels
viz_model = dtreeviz.model(clf,
                           X_train=X, y_train=y,
                           feature_names=iris.feature_names,
                           target_name='iris',
                           class_names=iris.target_names)
v = viz_model.view()      # render as SVG into an internal object
v.show()                  # pop up a viewer window
# v.save("/tmp/iris.svg")  # optionally save as SVG

Output:

[Figure: dtreeviz rendering of the iris decision tree]