蓄水池算法(Reservoir Sampling)

这个算法真的很奇妙,它的核心是一个数学证明。外延,或者说应用场景是:

  1. 蓄水池算法 - 图1,也就是从大小为n的样本集中随机取k个不同的样本
  2. 流式数据,或者说无法直接根据索引拿到数据(更加不可能一遍加载到内存)

算法描述

算法的描述其实很简单:维基百科:水塘抽样

  1. 问题描述:从包含n个不同的项目的集合S中随机选取k个不同的样本。
  2. 算法:
  3. S中取首k个放入[水塘]中
  4. 对每个S[j]项(j>=k,数组从0开始):
  5. 随机产生一个范围从0j的整数r
  6. r<k则把水塘中的第r项换成S[j]项
  7. 最后得到的水塘就是抽样结果

这个算法保证了每一项最后可能存在于水塘中的概率都是一样的。

单看算法,你肯定不知道为什么是等概率,其实数学证明并不难,请看下面的证明:

数学证明

我们把样本分为两类:

  1. 一类是首k个,它们一开始就在水塘中
  2. 一类是其他,它们一开始并不在水塘中

我们发现两个简单的逻辑:

  1. 对于水塘中的样本,只要随机数不选到该样本,该样本就不会被替换
  2. 水塘的某个项一旦被替换,就不可能再回到水塘,不会出现被替换掉,然后再回到水塘的局面,这样就保证了问题不会进一步变得复杂。所以:**某个项被保留的概率 = 被选中到水塘的概率 后续不被替换的概率

分类讨论,首k个样本最终存在于水塘中的概率,和其余样本最终存在于水塘中的概率:

  1. 首k个样本,随便选一个做研究对象。被选中到水塘的概率为:1。(数组从1开始)从j=k+1开始考虑替换,第一次不被替换的概率是蓄水池算法 - 图2,第二次不被替换的概率是蓄水池算法 - 图3,第三次…,一直到最后一次不被替换的概率是蓄水池算法 - 图4
    所以该项被保留的概率 = 蓄水池算法 - 图5
  2. 一开始不在水塘中的那一部分,随便选一个做研究对象。被选中到水塘的概率为:蓄水池算法 - 图6,后续不被替换的概率蓄水池算法 - 图7,一直到蓄水池算法 - 图8
    所以该项被保留的概率 = 蓄水池算法 - 图9

到此我们就证明了所以样本最终存在于水塘中的概率都是蓄水池算法 - 图10,这也完全符合了我们的数学期望。

代码

弄个流式数据我们这里没有条件,只能用伪代码模拟一下:

  1. public Data[] reservoirSampling(int k, DataStream dataStream){
  2. Data[] reservoir = new int[k];
  3. // init pool
  4. for(int i=0;i<reservoir.length;i++){
  5. reservoir[i] = dataStream.getCurrentData();
  6. dataStream.toNext();
  7. }
  8. Random random = new Random();
  9. for(int i=k;!dataStream.isFinish();i++){
  10. int d = random.nextInt(i+1);
  11. if(d<k){
  12. reservoir[d] = dataStream.getCurrentData();
  13. }
  14. dataStream.toNext();
  15. }
  16. return reservoir;
  17. }