https://leetcode-cn.com/problems/super-egg-drop/
n 层楼,k个鸡蛋,如何确定那一层楼扔下鸡蛋会碎,最少尝试多少次。

法一

状态定义:鸡蛋落地问题 - 图1为 n 层楼扔下 k 个鸡蛋的最少次数。
转移:尝试枚举第一个扔鸡蛋的楼层 鸡蛋落地问题 - 图2

  • 如果碎掉,问题变为 鸡蛋落地问题 - 图3
  • 如果没有碎掉:问题变为 鸡蛋落地问题 - 图4

状态转移:鸡蛋落地问题 - 图5
复杂度: 鸡蛋落地问题 - 图6
由于随着 n 的增加(k 一定),答案增加,所以第一部分递增,第二部分递减。
我们期望找到 鸡蛋落地问题 - 图7的最小值,那么就是找到这两个函数的交点。
找到最大的 鸡蛋落地问题 - 图8鸡蛋落地问题 - 图9
找到最小的 鸡蛋落地问题 - 图10鸡蛋落地问题 - 图11
那么最终答案在 鸡蛋落地问题 - 图12鸡蛋落地问题 - 图13中做选择即可。

以下代码采用递推法,如果改成记忆化,对于LC测试会快一些。

  1. class Solution {
  2. public:
  3. int superEggDrop(int k, int n) {
  4. vector<vector<int>> d(n + 1, vector<int>(k + 1, 1E9));
  5. for(int j = 0; j <= k; j++) d[0][j] = 0;
  6. for(int j = 1; j <= k; j++) d[1][j] = 1;
  7. for(int i = 2; i <= n; i++){
  8. for(int j = 1; j <= k; j++){
  9. int l = 1, r = i;
  10. while(l < r) {
  11. int m = (l + r + 1) / 2;
  12. if(d[m - 1][j - 1] <= d[i - m][j]) l = m;
  13. else r = m - 1;
  14. }
  15. int x0 = l;
  16. l = 1, r = i;
  17. while(l < r) {
  18. int m = (l + r) / 2;
  19. if(d[m - 1][j - 1] >= d[i - m][j]) r = m;
  20. else l = m + 1;
  21. }
  22. int x1 = l;
  23. int val1 = max(d[x0 - 1][j - 1], d[i - x0][j]) + 1;
  24. int val2 = max(d[x1 - 1][j - 1], d[i - x1][j]) + 1;
  25. d[i][j] = min(d[i][j], min(val1, val2));
  26. }
  27. }
  28. return d[n][k];
  29. }
  30. };

法二

鸡蛋落地问题 - 图14
注意观察,固定 k,随着 n 的增加,x 应当是增加的。
所以利用决策单调性求解即可,复杂度 鸡蛋落地问题 - 图15

  1. class Solution {
  2. public:
  3. int superEggDrop(int k, int n) {
  4. vector<int> d(n + 1);
  5. iota(d.begin(), d.end(), 0);
  6. for(int j = 2; j <= k; j++){
  7. vector<int> f(n + 1);
  8. int x = 1;
  9. f[1] = 1;
  10. for(int i = 2; i <= n; i++){
  11. while(x < i && max(d[x - 1], f[i - x]) >= max(d[x], f[i - x - 1])) x++;
  12. f[i] = max(d[x - 1], f[i - x]) + 1;
  13. }
  14. d = move(f);
  15. }
  16. return d[n];
  17. }
  18. };

法三

鸡蛋落地问题 - 图16表示 k 个鸡蛋尝试 t 次能够最多测试多少层。
随便扔一下,如果碎了,那么说明下面最多有 鸡蛋落地问题 - 图17层,如果没碎,说明上面最多可以有 鸡蛋落地问题 - 图18层。
所以 鸡蛋落地问题 - 图19

  1. class Solution {
  2. public:
  3. int superEggDrop(int k, int n) {
  4. if(n == 1) return 1;
  5. vector<vector<int>> f(n + 1, vector<int>(k + 1));
  6. for(int i = 1; i <= n; i++) {
  7. for(int j = 1; j <= k; j++){
  8. f[i][j] = 1 + f[i - 1][j] + f[i - 1][j - 1];
  9. }
  10. if(f[i][k] >= n) return i;
  11. }
  12. return n;
  13. }
  14. };