meta_srv/selector/
weighted_choose.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use rand::rng;
16use rand::seq::IndexedRandom;
17use snafu::ResultExt;
18
19use crate::error;
20use crate::error::Result;
21
22/// A common trait for weighted balance algorithm.
23pub trait WeightedChoose<Item>: Send + Sync {
24    /// The method will choose one item.
25    fn choose_one(&mut self) -> Result<Item>;
26
27    /// The method will choose multiple items.
28    ///
29    /// ## Note
30    ///
31    /// - Returns less than `amount` items if the weight_array is not enough.
32    /// - The returned items cannot be duplicated.
33    fn choose_multiple(&mut self, amount: usize) -> Result<Vec<Item>>;
34
35    /// Returns the length of the weight_array.
36    fn len(&self) -> usize;
37
38    /// Returns whether the weight_array is empty.
39    fn is_empty(&self) -> bool {
40        self.len() == 0
41    }
42}
43
/// A candidate paired with the weight used when choosing among candidates.
#[derive(Clone, Debug, PartialEq)]
pub struct WeightedItem<Item> {
    /// The payload handed back to the caller when this entry is chosen.
    pub item: Item,
    /// The weight attached to `item`.
    pub weight: f64,
}
50
/// An implementation of weighted balance: random weighted choose.
///
/// Conceptually, the algorithm works as follows: every element occupies a
/// segment of a line, and a random value picks the segment it lands on.
///
/// ```text
///           random value
/// ─────────────────────────────────▶
///                                  │
///                                  ▼
/// ┌─────────────────┬─────────┬──────────────────────┬─────┬─────────────────┐
/// │element_0        │element_1│element_2             │...  │element_n        │
/// └─────────────────┴─────────┴──────────────────────┴─────┴─────────────────┘
/// ```
pub struct RandomWeightedChoose<Item> {
    // The weighted candidates to choose from.
    items: Vec<WeightedItem<Item>>,
}
67
68impl<Item> RandomWeightedChoose<Item> {
69    pub fn new(items: Vec<WeightedItem<Item>>) -> Self {
70        Self { items }
71    }
72}
73
74impl<Item> Default for RandomWeightedChoose<Item> {
75    fn default() -> Self {
76        Self {
77            items: Vec::default(),
78        }
79    }
80}
81
82impl<Item> WeightedChoose<Item> for RandomWeightedChoose<Item>
83where
84    Item: Clone + Send + Sync,
85{
86    fn choose_one(&mut self) -> Result<Item> {
87        // unwrap safety: whether weighted_index is none has been checked before.
88        let item = self
89            .items
90            .choose_weighted(&mut rng(), |item| item.weight)
91            .context(error::ChooseItemsSnafu)?
92            .item
93            .clone();
94        Ok(item)
95    }
96
97    fn choose_multiple(&mut self, amount: usize) -> Result<Vec<Item>> {
98        let amount = amount.min(self.items.iter().filter(|item| item.weight > 0.0).count());
99
100        Ok(self
101            .items
102            .choose_multiple_weighted(&mut rng(), amount, |item| item.weight)
103            .context(error::ChooseItemsSnafu)?
104            .cloned()
105            .map(|item| item.item)
106            .collect::<Vec<_>>())
107    }
108
109    fn len(&self) -> usize {
110        self.items.len()
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::{RandomWeightedChoose, WeightedChoose, WeightedItem};

    /// With one positively weighted item and one zero-weight item, only the
    /// positively weighted item may ever be returned.
    #[test]
    fn test_random_weighted_choose() {
        let candidates = vec![
            WeightedItem {
                item: 1,
                weight: 100.0,
            },
            WeightedItem {
                item: 2,
                weight: 0.0,
            },
        ];
        let mut choose = RandomWeightedChoose::new(candidates);

        // The selection is random, so each check is repeated many times.
        (0..100).for_each(|_| {
            let ret = choose.choose_one().unwrap();
            assert_eq!(1, ret);
        });

        // Asking for more items than are selectable yields only the selectable one.
        (0..100).for_each(|_| {
            let ret = choose.choose_multiple(3).unwrap();
            assert_eq!(vec![1], ret);
        });
    }
}