
A joint article about causality and interpretable machine learning with Eleanor Dillon, Jacob LaRiviere, Scott Lundberg, Jonathan Roth, and Vasilis Syrgkanis from Microsoft.

Predictive machine learning models like XGBoost become even more powerful when paired with interpretability tools like SHAP. These tools identify the most informative relationships between the input features and the predicted outcome. That is useful for explaining what the model is doing, getting stakeholder buy-in, and diagnosing potential problems. It is tempting to take this analysis one step further and assume that interpretation tools can also identify which features decision makers should manipulate if they want to change outcomes in the future. However, in this article, we discuss how using predictive models to guide this kind of policy choice can often be misleading.

The reason relates to the fundamental difference between correlation and causation. SHAP makes transparent the correlations picked up by predictive ML models. But making correlations transparent does not make them causal! All predictive models implicitly assume that everyone will keep behaving the same way in the future, and therefore that correlation patterns will stay constant. To understand what happens if someone starts behaving differently, we need to build causal models, which requires making assumptions and using the tools of causal analysis.
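Here is a small illustration of the gap. It is a minimal toy simulation and not part of the article's example. A hypothetical hidden variable `need` drives both a feature `x` and an outcome `y`, and `x` has no effect on `y` at all. The two are strongly correlated in observational data, so a correlational fit suggests that raising `x` raises `y`. Setting `x` by intervention, which stands in for "someone starts behaving differently", leaves `y` unchanged.

```python
import numpy as np


def simulate(n, rng, x_override=None):
    """Hypothetical toy model: a hidden `need` drives both x and y; x has no effect on y."""
    need = rng.normal(size=n)                        # unobserved common cause
    x = need + rng.normal(scale=0.5, size=n)         # feature that merely tracks need
    if x_override is not None:
        x = np.full(n, x_override)                   # intervention: set x for everyone
    y = 2.0 * need + rng.normal(scale=0.5, size=n)   # outcome depends on need only
    return x, y


rng = np.random.default_rng(0)

# Observational data: x and y are strongly correlated
x, y = simulate(100_000, rng)
print("observational corr(x, y):", round(np.corrcoef(x, y)[0, 1], 3))

# A correlational model predicts y rises by `slope` per unit increase in x
slope = np.polyfit(x, y, deg=1)[0]
print("predictive slope of y on x:", round(slope, 3))

# Intervening on x (changing behavior) does not move y
_, y_low = simulate(100_000, rng, x_override=0.0)
_, y_high = simulate(100_000, rng, x_override=1.0)
effect = y_high.mean() - y_low.mean()
print("effect of setting x from 0 to 1:", round(effect, 3))
```

Analytically, the slope is cov(x, y) / var(x) = 2 / 1.25 = 1.6. The true effect of `x` is zero, so the predictive model's implied "effect" is entirely an artifact of the shared cause.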

1. A subscriber retention example

Imagine we are tasked with building a model that predicts whether a customer will renew their product subscription. Let's assume that after a bit of digging we manage to get eight features that are important for predicting churn: customer discount, ad spending, the customer's monthly usage, last upgrade, bugs reported by the customer, interactions with the customer, sales calls with the customer, and macroeconomic activity. We then use those features to train a basic XGBoost model to predict whether a customer will renew their subscription when it expires:

```python
import numpy as np
import pandas as pd
import scipy.special
import scipy.stats
import sklearn.model_selection
import xgboost


class FixableDataFrame(pd.DataFrame):
    """Helper class for manipulating generative models.

    Columns listed in `fixed` are overwritten with their fixed value whenever
    they are assigned, which lets us simulate interventions on the generator.
    """

    def __init__(self, *args, fixed={}, **kwargs):
        self.__dict__["__fixed_var_dictionary"] = fixed
        super(FixableDataFrame, self).__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        out = super(FixableDataFrame, self).__setitem__(key, value)
        if isinstance(key, str) and key in self.__dict__["__fixed_var_dictionary"]:
            out = super(FixableDataFrame, self).__setitem__(
                key, self.__dict__["__fixed_var_dictionary"][key]
            )
        return out


# generate the data
def generator(n, fixed={}, seed=0):
    """The generative model for our subscriber retention example."""
    if seed is not None:
        np.random.seed(seed)
    X = FixableDataFrame(fixed=fixed)

    # the number of sales calls made to this customer
    X["Sales calls"] = np.random.uniform(0, 4, size=(n,)).round()

    # the number of interactions with this customer (sales calls plus other contacts)
    X["Interactions"] = X["Sales calls"] + np.random.poisson(0.2, size=(n,))

    # the health of the regional economy this customer is a part of
    X["Economy"] = np.random.uniform(0, 1, size=(n,))

    # the time since the last product upgrade when this customer came up for renewal
    X["Last upgrade"] = np.random.uniform(0, 20, size=(n,))

    # how much the user perceives that they need the product
    X["Product need"] = X["Sales calls"] * 0.1 + np.random.normal(0, 1, size=(n,))

    # the fractional discount offered to this customer upon renewal
    X["Discount"] = ((1 - scipy.special.expit(X["Product need"])) * 0.5
                     + 0.5 * np.random.uniform(0, 1, size=(n,))) / 2

    # what percent of the days in the last period the user was actively using the product
    X["Monthly usage"] = scipy.special.expit(X["Product need"] * 0.3 + np.random.normal(0, 1, size=(n,)))

    # how much ad money we spent per user targeted at this user (or a group this user is in)
    X["Ad spend"] = (X["Monthly usage"] * np.random.uniform(0.9, 0.99, size=(n,))
                     + (X["Last upgrade"] < 1) + (X["Last upgrade"] < 2))

    # how many bugs this user encountered since their last renewal
    X["Bugs faced"] = np.array([np.random.poisson(v * 2) for v in X["Monthly usage"]])

    # how many bugs the user reported
    X["Bugs reported"] = (X["Bugs faced"] * scipy.special.expit(X["Product need"])).round()

    # did the user renew? ("Ad spend" enters with a zero coefficient, i.e. no effect)
    X["Did renew"] = scipy.special.expit(7 * (
        0.18 * X["Product need"]
        + 0.08 * X["Monthly usage"]
        + 0.1 * X["Economy"]
        + 0.05 * X["Discount"]
        + 0.05 * np.random.normal(0, 1, size=(n,))
        + 0.05 * (1 - X["Bugs faced"] / 20)
        + 0.005 * X["Sales calls"]
        + 0.015 * X["Interactions"]
        + 0.1 / (X["Last upgrade"] / 4 + 0.25)
        + X["Ad spend"] * 0.0 - 0.45
    ))

    # In real life we only observe whether the customer did or did not renew (0 or 1),
    # so we draw a Bernoulli label from the renewal probability. Comment out this line
    # to keep the probability as the label instead: the plots become less noisy, but
    # the basic results stay the same.
    X["Did renew"] = scipy.stats.bernoulli.rvs(X["Did renew"])
    return X


def user_retention_dataset():
    """The observed data for model training."""
    n = 10000
    X_full = generator(n)
    y = X_full["Did renew"]
    # "Product need" and "Bugs faced" are hidden: we cannot measure them in practice
    X = X_full.drop(["Did renew", "Product need", "Bugs faced"], axis=1)
    return X, y


def fit_xgboost(X, y):
    """Train an XGBoost model with early stopping."""
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X, y)
    dtrain = xgboost.DMatrix(X_train, label=y_train)
    dtest = xgboost.DMatrix(X_test, label=y_test)
    model = xgboost.train(
        {"eta": 0.001, "subsample": 0.5, "max_depth": 2, "objective": "reg:logistic"},
        dtrain,
        num_boost_round=200000,
        evals=((dtest, "test"),),
        early_stopping_rounds=20,
        verbose_eval=False,
    )
    return model


X, y = user_retention_dataset()
model = fit_xgboost(X, y)
```

Once we have our XGBoost customer retention model in hand, we can begin exploring what it has learned with an interpretability tool like SHAP. We start by plotting the global importance of each feature in the model:

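```python
import shap

# Explain the model's predictions with SHAP (the tree algorithm is selected automatically for XGBoost)
explainer = shap.Explainer(model)
shap_values = explainer(X)

# Cluster features by redundancy so the bar plot can show which features carry similar information
clust = shap.utils.hclust(X, y, linkage="complete")
shap.plots.bar(shap_values, clustering=clust, clustering_cutoff=1)
```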

[Figure: SHAP bar plot of global feature importance, with feature redundancy clustering]
This bar plot shows that the discount offered, ad spend, and number of bugs reported are the top three factors driving the model's prediction of customer retention. This is interesting and at first glance looks reasonable. The bar plot also includes a feature redundancy clustering, which we will use later.
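The bar lengths come from a simple aggregation. By default, SHAP's bar plot summarizes each feature by its mean absolute SHAP value across all rows, as described in the SHAP documentation. The sketch below applies that aggregation to a small hypothetical SHAP matrix. The numbers are invented for illustration and are not the article's results.

```python
import numpy as np

# Hypothetical SHAP values: rows are customers, columns are features.
feature_names = ["Discount", "Ad spend", "Bugs reported"]
shap_matrix = np.array([
    [-0.30,  0.10,  0.20],
    [ 0.15, -0.20,  0.05],
    [-0.30,  0.30, -0.05],
])

# Global importance = mean absolute SHAP value per feature (column).
global_importance = np.abs(shap_matrix).mean(axis=0)

# Rank features from most to least important, as the bar plot orders its bars.
order = np.argsort(-global_importance)
ranking = [feature_names[i] for i in order]
for i in order:
    print(f"{feature_names[i]:>14}: {global_importance[i]:.3f}")
```

Taking absolute values means a feature that pushes predictions strongly down counts as much as one that pushes them up. The bar plot therefore tells us *how much* a feature matters to the model, not in which direction. With the article's objects, `np.abs(shap_values.values).mean(0)` computes the same quantity.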

However, when we dig deeper and look at how changing the value of each feature impacts the model's prediction, we find some unintuitive patterns. SHAP scatter plots show how changing the value of a feature impacts the model's predicted renewal probability. If the blue dots follow an increasing pattern, then the larger the feature value, the higher the model's predicted renewal probability.

```python
# Scatter each feature's values against its SHAP values (y-axis: effect on predicted renewal)
shap.plots.scatter(shap_values[:, "Bugs reported"], ylabel="SHAP value\n(higher means more likely to renew)")
shap.plots.scatter(shap_values[:, "Discount"], ylabel="SHAP value\n(higher means more likely to renew)")
```

[Figures: SHAP scatter plots for "Bugs reported" and "Discount"]
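The direction of a scatter plot can also be summarized as a number. One option, which is our own choice and not something SHAP's scatter plot computes, is the Spearman rank correlation between a feature's values and its SHAP values. A positive value means larger feature values tend to raise the predicted renewal probability. The function `shap_direction` and the arrays below are hypothetical, shaped only to mimic the rising and falling patterns described in the text. Spearman captures only monotone trends, so a U-shaped scatter would score near zero.

```python
import numpy as np
from scipy.stats import spearmanr


def shap_direction(feature_values, shap_values):
    """Spearman rank correlation between a feature and its SHAP values.

    > 0: larger feature values tend to raise the prediction.
    < 0: larger feature values tend to lower the prediction.
    """
    rho, _pvalue = spearmanr(feature_values, shap_values)
    return rho


# Hypothetical data mimicking the two patterns described in the text.
bugs_reported = np.array([0, 1, 2, 3, 4, 5])
bugs_shap = np.array([-0.10, -0.02, 0.03, 0.05, 0.08, 0.09])        # rising pattern
discount = np.array([0.05, 0.10, 0.15, 0.20, 0.25, 0.30])
discount_shap = np.array([0.06, 0.04, 0.01, -0.02, -0.05, -0.07])   # falling pattern

print("Bugs reported:", shap_direction(bugs_reported, bugs_shap))
print("Discount:", shap_direction(discount, discount_shap))
```

Applied to the article's objects, the equivalent call would be `shap_direction(X["Discount"], shap_values[:, "Discount"].values)`.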

2. Prediction tasks versus causal tasks

The scatter plots show some surprising findings: Users who report more bugs are more likely to renew! Users with larger discounts are less likely to renew!

We triple-check our code and data pipelines to rule out a bug, then talk to some business partners who offer an intuitive explanation: Users with high usage who value the product are more likely to report bugs and to renew their subscriptions. The sales force tends to give high discounts to customers they think are less likely to be interested in the product, and these customers have higher churn.
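The sales-team explanation is a classic confounding story. Here is a minimal sketch of it. It is a hypothetical linear simulation, not the article's generator: low product need leads to larger discounts, and discounts have a small *positive* true effect on a renewal score. Regressing renewal on discount alone yields a negative coefficient. Adding product need as a control recovers the positive effect. That adjustment works only because the simulation lets us observe product need, which the real retention model cannot see.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Hypothetical data-generating process (not the article's generator):
product_need = rng.normal(size=n)                      # hidden in the real data
# the sales team gives bigger discounts to customers with low product need
discount = 0.25 - 0.05 * product_need + rng.normal(scale=0.05, size=n)
true_discount_effect = 0.5                             # discounts truly help renewal
renew_score = (1.0 * product_need
               + true_discount_effect * discount
               + rng.normal(scale=0.5, size=n))


def ols_coefficients(columns, target):
    """Ordinary least squares coefficients, intercept first."""
    design = np.column_stack([np.ones(len(target))] + list(columns))
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return coef


# Regressing on discount alone: the confounder flips the sign
naive = ols_coefficients([discount], renew_score)[1]
# Also controlling for product need (possible only because we simulated it)
adjusted = ols_coefficients([discount, product_need], renew_score)[1]
print(f"naive discount coefficient:    {naive:.2f}")
print(f"adjusted discount coefficient: {adjusted:.2f}")
```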

Are these initially counterintuitive relationships in the model a problem? That depends on what our goal is!
Our original goal for this model was to predict customer retention, which is useful for projects like estimating future revenue for financial planning. Since users reporting more bugs are in fact more likely to renew, capturing this relationship in the model is helpful for prediction. As long as our model has good fit out-of-sample, we should be able to provide finance with a good prediction, and therefore shouldn’t worry about the direction of this relationship in the model.

This is an example of a class of tasks called prediction tasks. In a prediction task, the goal is to predict an outcome Y (e.g. renewals) given a set of features X. A key component of a prediction exercise is that we only care that the prediction model(X) is close to Y in data distributions similar to our training set. A simple correlation between X and Y can be helpful for these types of predictions.
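The claim that a simple correlation helps prediction can be checked with a small simulation. This sketch again uses a hypothetical hidden `need` variable. Here "bugs reported" is only a proxy for need and plays no role in renewal. A linear predictor built on it still beats a constant baseline on held-out data, because the test data come from the same distribution as the training data.

```python
import numpy as np

rng = np.random.default_rng(1)


def draw(n):
    """Hypothetical population: need drives both bug reports and renewal."""
    need = rng.normal(size=n)
    bugs_reported = np.maximum(0.0, need + rng.normal(scale=0.7, size=n))
    renew_prob = 1.0 / (1.0 + np.exp(-need))   # bugs_reported plays no role here
    return bugs_reported, renew_prob


x_train, y_train = draw(50_000)
x_test, y_test = draw(50_000)      # same distribution as training

slope, intercept = np.polyfit(x_train, y_train, deg=1)
pred = slope * x_test + intercept
baseline = np.full_like(y_test, y_train.mean())

mse_model = np.mean((y_test - pred) ** 2)
mse_baseline = np.mean((y_test - baseline) ** 2)
print(f"held-out MSE, proxy model: {mse_model:.4f}")
print(f"held-out MSE, constant:    {mse_baseline:.4f}")
```

For a prediction task this is a legitimate win. The fitted slope is positive, but it says nothing about what would happen if we changed how many bugs customers report.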

However, suppose a second team picks up our prediction model with the new goal of determining what actions our company can take to retain more customers. This team cares a lot about how each X feature relates to Y, not just in our training distribution, but also in the counterfactual scenario produced when the world changes. In that use case, it is no longer sufficient to identify a stable correlation between variables; this team wants to know whether manipulating feature X will cause a change in Y. Picture the face of the chief of engineering when you tell him that you want him to introduce new bugs to increase customer renewals!

This is an example of a class of tasks called causal tasks. In a causal task, we want to know how changing an aspect of the world X (e.g., bugs reported) affects an outcome Y (e.g., renewals). In this case, it's critical to know whether changing X causes an increase in Y, or whether the relationship in the data is merely correlational.

3. The challenges of estimating causal effects

A useful tool for understanding causal relationships is writing down a causal graph of the data-generating process we're interested in. A causal graph of our example illustrates why the robust predictive relationships picked up by our XGBoost customer retention model differ from the causal relationships of interest to the team that wants to plan interventions to increase retention. This graph is just a summary of the true data-generating mechanism (which is defined above). Solid ovals represent features that we observe, while dashed ovals represent hidden features that we don't measure. Each feature is a function of all the features with an arrow to it, plus some random effects.

```python
import graphviz

# observed features (solid ovals)
names = [
    "Bugs reported", "Monthly usage", "Sales calls", "Economy",
    "Discount", "Last upgrade", "Ad spend", "Interactions"
]
g = graphviz.Digraph()
for name in names:
    g.node(name, fontsize="10")

# hidden features we do not measure (dashed) and the outcome (filled)
g.node("Product need", style="dashed", fontsize="10")
g.node("Bugs faced", style="dashed", fontsize="10")
g.node("Did renew", style="filled", fontsize="10")

# each edge points from a cause to the feature it influences
g.edge("Product need", "Did renew")
g.edge("Product need", "Discount")
g.edge("Product need", "Bugs reported")
g.edge("Product need", "Monthly usage")
g.edge("Discount", "Did renew")
g.edge("Monthly usage", "Bugs faced")
g.edge("Monthly usage", "Did renew")
g.edge("Monthly usage", "Ad spend")
g.edge("Economy", "Did renew")
g.edge("Sales calls", "Did renew")
g.edge("Sales calls", "Product need")
g.edge("Sales calls", "Interactions")
g.edge("Interactions", "Did renew")
g.edge("Bugs faced", "Did renew")
g.edge("Bugs faced", "Bugs reported")
g.edge("Last upgrade", "Did renew")
g.edge("Last upgrade", "Ad spend")

g  # in a Jupyter notebook, this renders the graph
```

[Figure: causal graph of the subscriber retention data-generating process]
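The graph can also be queried programmatically. The sketch below uses only the standard library and the same edge list as the Graphviz code above. For a given feature, it reports whether that feature has a direct arrow into "Did renew" and which direct parents it shares with renewal. This is a deliberate simplification of a full back-door analysis, and the helper `shared_parents` is hypothetical. For "Bugs reported", there is no direct arrow, and both shared parents are hidden. For "Discount", there is a direct arrow, but the feature also shares the hidden "Product need" with renewal. Neither feature's observed association with renewal can therefore be read as a causal effect.

```python
from collections import defaultdict

# Same edges as the Graphviz causal graph (parent -> child).
EDGES = [
    ("Product need", "Did renew"), ("Product need", "Discount"),
    ("Product need", "Bugs reported"), ("Product need", "Monthly usage"),
    ("Discount", "Did renew"),
    ("Monthly usage", "Bugs faced"), ("Monthly usage", "Did renew"),
    ("Monthly usage", "Ad spend"),
    ("Economy", "Did renew"),
    ("Sales calls", "Did renew"), ("Sales calls", "Product need"),
    ("Sales calls", "Interactions"),
    ("Interactions", "Did renew"),
    ("Bugs faced", "Did renew"), ("Bugs faced", "Bugs reported"),
    ("Last upgrade", "Did renew"), ("Last upgrade", "Ad spend"),
]
HIDDEN = {"Product need", "Bugs faced"}   # dashed ovals: not measured

parents = defaultdict(set)
for parent, child in EDGES:
    parents[child].add(parent)


def shared_parents(feature, outcome="Did renew"):
    """Direct causes that a feature and the outcome have in common."""
    return parents[feature] & parents[outcome]


for feature in ["Bugs reported", "Discount"]:
    common = shared_parents(feature)
    has_direct_edge = feature in parents["Did renew"]
    print(f"{feature}: direct effect on renewal={has_direct_edge}, "
          f"shared parents={sorted(common)}, hidden={sorted(common & HIDDEN)}")
```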