1. package com.jideos.jnotes.utils;
    2. import lombok.AllArgsConstructor;
    3. import lombok.Data;
    4. import lombok.NoArgsConstructor;
    5. import java.util.*;
    6. import java.util.concurrent.ThreadLocalRandom;
    7. /**
    8. * @author 杨胖胖
    9. */
    10. public class WeightRandom<T> {
    11. private final List<T> items = new ArrayList<>();
    12. private double[] weights;
    13. public WeightRandom(List<ItemWithWeight<T>> itemsWithWeight) {
    14. this.calWeights(itemsWithWeight);
    15. }
    16. /**
    17. * 计算权重,初始化或者重新定义权重时使用
    18. */
    19. public void calWeights(List<ItemWithWeight<T>> itemsWithWeight) {
    20. items.clear();
    21. // 计算权重总和
    22. double originWeightSum = 0;
    23. for (ItemWithWeight<T> itemWithWeight : itemsWithWeight) {
    24. double weight = itemWithWeight.getWeight();
    25. if (weight <= 0) {
    26. continue;
    27. }
    28. items.add(itemWithWeight.getItem());
    29. if (Double.isInfinite(weight)) {
    30. weight = 10000.0D;
    31. }
    32. if (Double.isNaN(weight)) {
    33. weight = 1.0D;
    34. }
    35. originWeightSum += weight;
    36. }
    37. // 计算每个item的实际权重比例
    38. double[] actualWeightRatios = new double[items.size()];
    39. int index = 0;
    40. for (ItemWithWeight<T> itemWithWeight : itemsWithWeight) {
    41. double weight = itemWithWeight.getWeight();
    42. if (weight <= 0) {
    43. continue;
    44. }
    45. actualWeightRatios[index++] = weight / originWeightSum;
    46. }
    47. // 计算每个item的权重范围
    48. // 权重范围起始位置
    49. weights = new double[items.size()];
    50. double weightRangeStartPos = 0;
    51. for (int i = 0; i < index; i++) {
    52. weights[i] = weightRangeStartPos + actualWeightRatios[i];
    53. weightRangeStartPos += actualWeightRatios[i];
    54. }
    55. }
    56. /**
    57. * 基于权重随机算法选择
    58. */
    59. public T choose() {
    60. double random = ThreadLocalRandom.current().nextDouble();
    61. int index = Arrays.binarySearch(weights, random);
    62. if (index < 0) {
    63. index = -index - 1;
    64. } else {
    65. return items.get(index);
    66. }
    67. if (index < weights.length && random < weights[index]) {
    68. return items.get(index);
    69. }
    70. // 通常不会走到这里,为了保证能得到正确的返回,这里随便返回一个
    71. return items.get(0);
    72. }
    73. @Data
    74. @AllArgsConstructor
    75. @NoArgsConstructor
    76. public static class ItemWithWeight<T> {
    77. T item;
    78. double weight;
    79. }
    80. }