为什么要用蓄水池采样

采样问题经常会涉及到以下类似问题:

  1. 从1000000份调查报告中抽1000份进行统计
  2. 从一本很厚的电话簿中抽1000人进行姓氏统计
  3. 从Google搜索”AI”的结果中抽100个进行分析

问题1很容易想到生成1至1000000的随机数,抽取1000个,用算法去重保证采样结果不重复即可。但问题2与问题3的性质与问题1不同,由于数据规模N可能十分巨大,没有办法将其一次全部读入内存,所以存储所有数据再遍历一次来获取其规模这一做法不可取。

采样问题最重要的是保证公平,即对于总样本集合中的所有元素来说,每个元素需具有相等的概率被选择。这里,我们需要使用蓄水池采样。

蓄水池采样是一个数据抽样算法,主要用来解决如下问题:

给定一个数据流,数据流长度N很大,且N直到处理完所有数据之前都不可知,请问如何在只遍历一遍数据(O(N))的情况下,能够随机选取出m个不重复的数据。

这个场景强调了3件事:

  1. 数据流长度N很大且不可知,所以不能一次性存入内存。
  2. 时间复杂度为O(N)。
  3. 随机选取m个数,每个数被选中的概率为m/N。

第1点限制了不能直接取N内的m个随机数,然后按索引取出数据。第2点限制了不能先遍历一遍,然后分块存储数据,再随机选取。第3点是数据选取绝对随机的保证。

算法流程

  • 假设数据序列的规模为 n,需要采样的数量的为 k。
  • 构建一个可容纳 k 个元素的数组,将序列的前 k 个元素放入数组中。
  • 从第 k+1 个元素开始,以 k/n 的概率来决定该元素是否被替换到数组中(数组中的元素被替换的概率是相同的)。
  • 当遍历完所有元素之后,数组中剩下的元素即为所需采取的样本。

证明如下:
image.png

代码实现

  1. int[] reservoir = new int[m];
  2. // init
  3. for (int i = 0; i < reservoir.length; i++)
  4. {
  5. reservoir[i] = dataStream[i];
  6. }
  7. for (int i = m; i < dataStream.length; i++)
  8. {
  9. // 随机获得一个[0, i]内的随机整数
  10. int d = rand.nextInt(i + 1);
  11. // 如果随机整数落在[0, m-1]范围内,则替换蓄水池中的元素
  12. if (d < m)
  13. {
  14. reservoir[d] = dataStream[i];
  15. }
  16. }
  1. import random
  2. class ReservoirSample(object):
  3. def __init__(self, size):
  4. self._size = size
  5. self._counter = 0
  6. self._sample = []
  7. def feed(self, item):
  8. self._counter += 1
  9. # 第i个元素(i <= k),直接进入池中
  10. if len(self._sample) < self._size:
  11. self._sample.append(item)
  12. return self._sample
  13. # 第i个元素(i > k),以k / i的概率进入池中
  14. rand_int = random.randint(1, self._counter)
  15. if rand_int <= self._size:
  16. self._sample[rand_int - 1] = item
  17. return self._sample