作者: 吴翱翔 / 后期编辑:张汉东


原文: 缓存解决动态规划难题

分享下 leetcode 困难题停在原地的方案数
不断推敲和优化逐步通过题目的过程

看到这种不同路径求方案总数,很容易想到 unique_path 这道动态规划入门题,

这题跟 unique_path 一样也是「求从起点到终点不同行走路径的方案总数」,自然想到用动态规划去实现

无记忆化的搜索

由于动态规划迭代解法的状态表达较难抽象,于是我先写出更简单动态规划递归的无记忆化搜索版本

递归的结束条件

那么递归的结束条件显然是「剩余步数」为 0

解答的更新条件

方案数的更新条件则是 剩余步数为 0 且 当前位置也是 0,这时候可以将方案数+1

递归函数的入参

首先需要当前位置和当前剩余步数两个”可变”的入参

再需要一个”常数”表达最大可前往的位置,一旦移动到数组的右边界,下一步就只能原地走或向左走

最后需要一个已发现方案总数的可变指针,用来更新解答集

递归搜索的决策层

只能在数组范围 [0, arr_len] 行走,行走方向 原地不动、向左、向右 三种

  1. 如果当前坐标是 0, 则只能 原地不动 或 向右
  2. 如果当前坐标是 arr_len-1,则只能 原地不动 或 向左
  3. 其余情况的行走方向决策则是 原地不动 或 向左 或 向右

无记忆化搜索代码

  1. fn num_ways_dfs(cur_position: i32, remain_steps: i32, max_position: i32, plans_count: &mut u32) {
  2. if remain_steps == 0 {
  3. if cur_position == 0 {
  4. // panicked at 'attempt to add with overflow'
  5. *plans_count += 1;
  6. }
  7. return;
  8. }
  9. // 剪枝: 走的太远不可能移动回原点的情况
  10. if cur_position > remain_steps {
  11. return;
  12. }
  13. // 做决策
  14. // 决策: 原地不动
  15. num_ways_dfs(cur_position, remain_steps-1, max_position, plans_count);
  16. if cur_position == 0 {
  17. // 只能向右
  18. num_ways_dfs(cur_position+1, remain_steps-1, max_position, plans_count);
  19. } else if cur_position == max_position {
  20. // 只能向左
  21. num_ways_dfs(cur_position-1, remain_steps-1, max_position, plans_count);
  22. } else {
  23. num_ways_dfs(cur_position+1, remain_steps-1, max_position, plans_count);
  24. num_ways_dfs(cur_position-1, remain_steps-1, max_position, plans_count);
  25. }
  26. }
  27. fn num_ways_dfs_entrance(steps: i32, arr_len: i32) -> i32 {
  28. let mut plans_count = 0;
  29. num_ways_dfs(0, steps, arr_len-1, &mut plans_count);
  30. (plans_count % (10_u32.pow(9)+7)) as i32
  31. }

虽然我加上了递归的剪枝条件,但是 leetcode 上只过了 1/3 的测试用例便在 (27,7) 这个测试用例上超时了

不仅如此,更新方案总数时还出现 u32 溢出的问题,我粗略估算下该函数的时间复杂度是 O(3^n) 指数级别的时间复杂度,其中 n 为剩余步数

非线性递归导致超时?

所谓线性递归大概指递归的决策层只有一个分支,或者说递归搜索树只有一个分支

像我上述代码的决策层有 向左/向右/原地不动 三种决策的就显然是个非线性递归,通常都很慢需要剪枝或记忆化才能提速

记忆化搜索

斐波那契递归的记忆化

斐波那契递归解法也是个典型的非线性递归

假设斐波那契数列的第 n 项为 fib(n),很容易想到斐波那契数列的 fib(3) 的搜索树可以展开为:

fib(3)=fib(2)+fib(1)=(fib(1)+fib(0))+fib(1)=2*fib(1)+fib(0)

我们发现 fib(1) 被重复计算了两次,所以业界有种「记忆化搜索」的优化策略

具体实现是定义一个 HashMap,key 为递归函数的入参,value 为该入参情况的计算结果

例如计算 fib(3) 的过程中,第一次遇到 fib(1) 这个入参时进行计算,并将计算结果存入 HashMap 中,

第二次递归调用 fib(1) 时可以直接从 HashMap 中查表取结果而不需要「重复计算」

这种优化思路有点像缓存,相信一个无状态的函数同样的入参一定能得到同样的结果,所以第二次遇到同样的入参时直接拿上一次相同入参的计算结果去返回

记忆化搜索的实现条件

我第一版的递归搜索代码中,方案总数作为可变指针参数来传入,这种写法「不能用记忆化搜索优化」

因函数 fn num_ways_dfs(cur_position: i32, remain_steps: i32, max_position: i32, plans_count: &mut u32)

并没有返回值,我无法实现一个 key 为入参,value 为该入参的上次计算结果返回值这样的记忆化缓存

逆向思维: 自下而上的递归

假设 f(pos,steps)=plans 表示从原点出发,当前位置 pos,剩余步数为 steps 的方案总数 plans

很容易想到 状态转移规律: f(0,0)=f(0,1)+f(1,1)

也就是终点是原点的前一个状态只能是: 前一个位置是 0 然后选择原地不动 或 前一个位置是 1 然后向左走

然后参考「数学归纳法」可以按照相同的规律将 f(0,1) 和 f(1,1) 也展开成子项,直到展开成 f(0, steps) 也就是起点

记忆化搜索的函数签名

  1. struct NumWaysHelper {
  2. max_position: i32,
  3. steps: i32,
  4. /// memo
  5. cache: std::collections::HashMap<(i32, i32), u64>
  6. }
  7. impl NumWaysHelper {
  8. fn dfs(&mut self, cur_pos: i32, remain_steps: i32) -> u64 {
  9. // TODO 递归结束条件
  10. let mut plans_count = 0;
  11. // 做决策/状态转移
  12. // 上一步是: 原地不动
  13. // TODO
  14. if cur_pos == 0 {
  15. // 上一步是: 向左
  16. // TODO
  17. } else if cur_pos == self.max_position {
  18. // 上一步是: 向左
  19. // TODO
  20. } else {
  21. // 上一步是: 向左或向右
  22. // TODO
  23. }
  24. self.cache.insert((cur_pos, remain_steps), plans_count);
  25. plans_count
  26. }
  27. }

缓存的写入

其中最关键的就是 self.cache.insert((cur_pos, remain_steps), plans_count); 这行

函数在 return 前先把(当前入参,返回值)这对计算结果「缓存到 HashMap」中

利用缓存避免重复计算

  1. let mut plans_count = 0;
  2. // 做决策/状态转移
  3. // 上一步是: 原地不动
  4. if let Some(plans) = self.cache.get(&(cur_pos, remain_steps+1)) {
  5. plans_count += *plans;
  6. } else {
  7. plans_count += self.dfs(cur_pos, remain_steps+1);
  8. }

因为递归调用的开销挺大的,以上上一步是原地不动的决策分支中,一旦发现之前运算过 (cur_pos, remain_steps+1) 的入参情况就直接取缓存中的上次计算结果(因为函数是无状态的,相同的入参一定能得到相同的结果)

记忆化搜索版本的实现

  1. struct NumWaysHelper {
  2. max_position: i32,
  3. steps: i32,
  4. cache: std::collections::HashMap<(i32, i32), u64>
  5. }
  6. impl NumWaysHelper {
  7. fn dfs(&mut self, cur_pos: i32, remain_steps: i32) -> u64 {
  8. if remain_steps == self.steps {
  9. if cur_pos == 0 {
  10. return 1;
  11. } else {
  12. // 只有从起点出发的方案才是有效的方案,其余方案都不可取(0)
  13. return 0;
  14. }
  15. }
  16. let mut plans_count = 0;
  17. // 做决策/状态转移
  18. // 共同的决策分支-上一步是: 原地不动
  19. plans_count += self.calc_plans_from_cache(cur_pos, remain_steps+1);
  20. if cur_pos == 0 {
  21. // 上一步是: 向左
  22. plans_count += self.calc_plans_from_cache(cur_pos+1, remain_steps+1);
  23. } else if cur_pos == self.max_position {
  24. // 上一步是: 向右
  25. plans_count += self.calc_plans_from_cache(cur_pos-1, remain_steps+1);
  26. } else {
  27. // 上一步是: 向左或向右
  28. plans_count += self.calc_plans_from_cache(cur_pos+1, remain_steps+1);
  29. plans_count += self.calc_plans_from_cache(cur_pos-1, remain_steps+1);
  30. }
  31. self.cache.insert((cur_pos, remain_steps), plans_count);
  32. plans_count
  33. }
  34. fn calc_plans_from_cache(&mut self, last_pos: i32, last_remain_steps: i32) -> u64 {
  35. if let Some(plans) = self.cache.get(&(last_pos, last_remain_steps)) {
  36. *plans
  37. } else {
  38. self.dfs(last_pos, last_remain_steps)
  39. }
  40. }
  41. }

本题缓存与数据库缓存的异同

MySQL 为了提高短时间相同 Query 的查询速度,会将查询的 SQL 语句计算哈希和对应的查询结果存入 Query Cache

在缓存的有效期内,遇到第二个相同的 SQL 查询就能直接从缓存中获取上次查询结果进行返回

MySQL 将 SQL 语句进行哈希是不是跟我们这题将递归调用的入参元祖作为 key 存入 HashMap 类似?

除了数据库,graphql 和 dataloader 也是大量用到了缓存,也是将查询计算 hash 作为 key 存入 HashMap 中

可以了解下 dataloader 这个 crate 的 源码
是如何进行缓存以及解决 N+1 查询的问题的

解决溢出错误

我们记忆化搜索的解法通过了80%的测试用例,但是在输入参数特别大时就出错了

  1. 输入:
  2. 93
  3. 85
  4. 输出:
  5. 468566822
  6. 预期结果:
  7. 623333920

看到期待值不对很多人以为「是不是我算法写错了」?

其实不是,一般这种入参很大的都是整数溢出的问题,leetcode 的 Rust 用的是溢出时自动 wrapping 的 release 编译

所谓 wrapping 值得就例如 0_u8.wrapping_sub(1)==255,0_u8 减 1 会下溢成 255

由于 leetcode 的题目描述中也提示了 方案总数可能会很大,所以每次加法都需要取模避免 i32 溢出

我也尝试修改 type PlansCount = i32,就算方案数用 u128 存储也会溢出,所以还是老老实实加法后取模

题解完整代码及测试代码

  1. type PlansCount = i32;
  2. struct NumWaysHelper {
  3. max_position: i32,
  4. steps: i32,
  5. /// memo
  6. cache: std::collections::HashMap<(i32, i32), PlansCount>,
  7. }
  8. impl NumWaysHelper {
  9. /// leetcode rust version not support const_fn pow
  10. const MOD: PlansCount = 1_000_000_007;
  11. fn dfs(&mut self, cur_pos: i32, remain_steps: i32) -> PlansCount {
  12. // 递归结束条件
  13. if remain_steps == self.steps {
  14. if cur_pos == 0 {
  15. return 1;
  16. }
  17. // 只有从起点出发的方案才是有效的方案,其余方案都不可取(0)
  18. return 0;
  19. }
  20. // 做决策/状态转移
  21. // 共同的决策分支: 上一步-原地不动
  22. let mut plans_count = self.calc_plans_from_cache(cur_pos, remain_steps + 1);
  23. if cur_pos == 0 {
  24. // 上一步是: 向左
  25. plans_count += self.calc_plans_from_cache(cur_pos + 1, remain_steps + 1);
  26. } else if cur_pos == self.max_position {
  27. // 上一步是: 向右
  28. plans_count += self.calc_plans_from_cache(cur_pos - 1, remain_steps + 1);
  29. } else {
  30. // 上一步是: 向左或向右
  31. plans_count += self.calc_plans_from_cache(cur_pos + 1, remain_steps + 1);
  32. plans_count =
  33. plans_count % Self::MOD + self.calc_plans_from_cache(cur_pos - 1, remain_steps + 1);
  34. }
  35. self.cache.insert((cur_pos, remain_steps), plans_count);
  36. plans_count
  37. }
  38. /// can't use map_or_else, reason: Error: closure requires unique access to `self` but `self` is already borrowed
  39. #[allow(clippy::option_if_let_else)]
  40. fn calc_plans_from_cache(&mut self, last_pos: i32, last_remain_steps: i32) -> PlansCount {
  41. (if let Some(plans) = self.cache.get(&(last_pos, last_remain_steps)) {
  42. *plans
  43. } else {
  44. self.dfs(last_pos, last_remain_steps)
  45. }) % Self::MOD
  46. }
  47. }
  48. fn num_ways_dfs_entrance(steps: i32, arr_len: i32) -> i32 {
  49. let mut helper = NumWaysHelper {
  50. max_position: arr_len - 1,
  51. steps,
  52. cache: std::collections::HashMap::new(),
  53. };
  54. helper.dfs(0, 0) % NumWaysHelper::MOD
  55. }
  56. #[test]
  57. fn test_num_ways() {
  58. const TEST_CASES: [(i32, i32, i32); 4] = [(93, 85, 623333920), (3, 2, 4), (2, 4, 2), (4, 2, 8)];
  59. for (steps, arr_len, plans_count) in TEST_CASES {
  60. assert_eq!(num_ways_dfs_entrance(steps, arr_len), plans_count);
  61. }
  62. }

完整源码: https://github.com/pymongo/leetcode-rust/blob/b6f0101a50a70512c12dd33333bfa535307ac40e/src/dp/number_of_ways_to_stay_in_the_same_place_after_some_steps.rs#L277

小结下逐步优化题解的过程

首先是根据题目意思写出了无缓存/无记忆化的从搜索树自上而下的递归解法,实现的过程中逐步理解了动态规划的状态转移方程,

进而写出了带缓存的深度优先搜索解法,解决了溢出等小问题后终于通过了

image.png

为什么不是 dp[i][j] 的动态规划写法

有读者可能疑惑,为什么 leetcode 这题官方题解或绝大部分题解都用 dp[i][j] 这种写法

我的题解运行速度比 dp[i][j] 的写法慢得多

首先我要明确一点动态规划其实是有两种主流的写法的,一种就是常见的 dp[i][j] 迭代写法去填表

另一种就是我介绍的递归记忆化/缓存化搜索

dp[i][j] 写法的最大毛病就是「可读性极差」,构思难度高

我以前写的动态规划代码,过五个月再看完全忘记 i 和 j 表达什么意思了

dfs(cur_position: i32, remain_steps: i32) 这种写法不比 dp[i][j] 的可读性强很多?

记忆化搜索另一种好处就是,可以快速写出简单的无缓存版本,再慢慢优化解决超时问题,而迭代的动态规划写法起步就很难

所以我个人更推荐大家多练习记忆化搜索解动态规划,这种借鉴数据库缓存的思路还是很简单的,面试中遇到不熟悉的动态规划题可以先试着用记忆化搜索去解决