在做一道题时,如何判断应该用前序还是中序还是后序遍历的框架?
根据题意,思考一个二叉树节点需要做什么,到底用什么遍历顺序就清楚了。

652. 寻找重复的子树

简单解释下题目,输入是一棵二叉树的根节点root,返回的是一个列表,里面装着若干个二叉树节点,这些节点对应的子树在原二叉树中是存在重复的。
image.png
那上面这张图举例,首先节点 4 本身可以作为一颗子树,并且二叉树中多个节点 4;还存在两颗以节点 2 为根的重复子树,那么我们返回的 List 中就应该有两个TreeNode,值分别为 4 和 2。

拿到题,还是老套路,先思考对于某一个节点,它应该做什么?
image.png
比如,我们站在节点 2 上,现在我们想知道以节点 2 为根的子树是否重复,是否应该被加入到结果列表中,我们需要知道哪些信息?
首先,我们需要知道下面两组信息:

  1. 以我为根的这课二叉树(子树)长什么样?
  2. 以其他节点为根的子树长什么样?

以我为根的这课二叉树(子树)长什么样?

后序遍历

看到这个问题,就可以判断这道题要用「后序遍历」框架来解决。

  1. void traverse(TreeNode root) {
  2. traverse(root.left);
  3. traverse(root.right);
  4. /* 解法代码的位置 */
  5. }

理由是,我要知道以自己为根的子树长什么样,我要先知道我的左右子树长什么样,最后再加上自己,就构成了整颗子树的样子。

这里我们可以举一个例子,计算一颗二叉树有多少个节点。

int count(TreeNode root) {
    if (root == null) {
        return 0;
    }
    // 先算出左右子树有多少节点
    int left = count(root.left);
    int right = count(root.right);
    /* 后序遍历代码位置 */
    // 加上自己,就是整棵二叉树的节点数
    int res = left + right + 1;
    return res;
}

这就是标准的后序遍历框架,既然我们要计算二叉树有多少个节点,我当然要先知道以我为根的二叉树的左右子树有多少节点,最后再加上我这个根节点,不就是二叉树的节点数嘛?

如何描述二叉树结构

现在我们明确了要用「后序遍历」,但是如何描述一颗二叉树的模样呢?在之前序列化篇(297、剑指37、剑指48)中,二叉树的前序/中序/后序的遍历结果可以描述二叉树的结构。所以我们可以通过拼接字符串的方式把二叉树序列化:

String traverse(TreeNode root) {
    // 对于空节点,可以用一个特殊字符表示
    if (root == null) {
        return "#";
    }
    // 将左右子树序列化成字符串
    String left = traverse(root.left);
    String right = traverse(root.right);
    /* 后序遍历代码位置 */
    // 左右子树加上自己,就是以自己为根的二叉树序列化结果
    String subTree = left + "," + right + "," + root.val;
    return subTree;
}

我们用非数字的特殊符 # 表示空指针,并且用字符 , 分隔每个二叉树节点值,这属于序列化二叉树的套路了。
注意我们subTree是按照左子树、右子树、根节点这样的顺序拼接字符串,也就是后序遍历顺序。你完全可以按照前序或者中序的顺序拼接字符串,因为这里只是为了描述一棵二叉树的样子,什么顺序不重要。
这样,我们第一个问题就解决了,对于每个节点,递归函数中的subTree变量就可以描述以该节点为根的二叉树。

以其他节点为根的子树长什么样

现在我们解决第二个问题,我知道了自己长啥样,怎么知道别人长啥样?这样我才能知道有没有其他子树跟我重复对吧。
这很简单呀,我们借助一个外部数据结构,让每个节点把自己子树的序列化结果存进去,这样,对于每个节点,不就可以知道有没有其他节点的子树和自己重复了么?
这里使用 HashMap,记录每一颗子树的出现次数:

最终代码

// 记录所有子树以及出现的次数
HashMap<String, Integer> memo = new HashMap<>();
// 记录重复的子树根节点
LinkedList<TreeNode> res = new LinkedList<>();

/* 主函数 */
List<TreeNode> findDuplicateSubtrees(TreeNode root) {
    traverse(root);
    return res;
}

/* 辅助函数 */
String traverse(TreeNode root) {
    if (root == null) {
        return "#";
    }

    String left = traverse(root.left);
    String right = traverse(root.right);

    String subTree = left + "," + right+ "," + root.val;

    int freq = memo.getOrDefault(subTree, 0);
    // 多次重复也只会被加入结果集一次
    if (freq == 1) {
        res.add(root);
    }
    // 给子树对应的出现次数加一
    memo.put(subTree, freq + 1);
    return subTree;
}