meta_srv/selector/weighted_choose.rs
1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use rand::rng;
16use rand::seq::IndexedRandom;
17use snafu::ResultExt;
18
19use crate::error;
20use crate::error::Result;
21
22/// A common trait for weighted balance algorithm.
23pub trait WeightedChoose<Item>: Send + Sync {
24 /// The method will choose one item.
25 fn choose_one(&mut self) -> Result<Item>;
26
27 /// The method will choose multiple items.
28 ///
29 /// ## Note
30 ///
31 /// - Returns less than `amount` items if the weight_array is not enough.
32 /// - The returned items cannot be duplicated.
33 fn choose_multiple(&mut self, amount: usize) -> Result<Vec<Item>>;
34
35 /// Returns the length of the weight_array.
36 fn len(&self) -> usize;
37
38 /// Returns whether the weight_array is empty.
39 fn is_empty(&self) -> bool {
40 self.len() == 0
41 }
42}
43
/// A candidate value paired with the weight that biases how often it is chosen.
#[derive(Clone, Debug, PartialEq)]
pub struct WeightedItem<Item> {
    /// The value handed back to callers when this candidate is chosen.
    pub item: Item,
    /// Relative weight of `item`; a larger weight makes it more likely to be chosen.
    pub weight: f64,
}
50
51/// A implementation of weighted balance: random weighted choose.
52///
53/// The algorithm is as follows:
54///
55/// ```text
56/// random value
57/// ─────────────────────────────────▶
58/// │
59/// ▼
60/// ┌─────────────────┬─────────┬──────────────────────┬─────┬─────────────────┐
61/// │element_0 │element_1│element_2 │... │element_n │
62/// └─────────────────┴─────────┴──────────────────────┴─────┴─────────────────┘
63/// ```
64pub struct RandomWeightedChoose<Item> {
65 items: Vec<WeightedItem<Item>>,
66}
67
68impl<Item> RandomWeightedChoose<Item> {
69 pub fn new(items: Vec<WeightedItem<Item>>) -> Self {
70 Self { items }
71 }
72}
73
74impl<Item> Default for RandomWeightedChoose<Item> {
75 fn default() -> Self {
76 Self {
77 items: Vec::default(),
78 }
79 }
80}
81
82impl<Item> WeightedChoose<Item> for RandomWeightedChoose<Item>
83where
84 Item: Clone + Send + Sync,
85{
86 fn choose_one(&mut self) -> Result<Item> {
87 // unwrap safety: whether weighted_index is none has been checked before.
88 let item = self
89 .items
90 .choose_weighted(&mut rng(), |item| item.weight)
91 .context(error::ChooseItemsSnafu)?
92 .item
93 .clone();
94 Ok(item)
95 }
96
97 fn choose_multiple(&mut self, amount: usize) -> Result<Vec<Item>> {
98 let amount = amount.min(self.items.iter().filter(|item| item.weight > 0.0).count());
99
100 Ok(self
101 .items
102 .choose_multiple_weighted(&mut rng(), amount, |item| item.weight)
103 .context(error::ChooseItemsSnafu)?
104 .cloned()
105 .map(|item| item.item)
106 .collect::<Vec<_>>())
107 }
108
109 fn len(&self) -> usize {
110 self.items.len()
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::{RandomWeightedChoose, WeightedChoose, WeightedItem};

    /// An item with zero weight must never be selected, whether a single item
    /// or several items are requested.
    #[test]
    fn test_random_weighted_choose() {
        let heavy = WeightedItem {
            item: 1,
            weight: 100.0,
        };
        let weightless = WeightedItem {
            item: 2,
            weight: 0.0,
        };
        let mut choose = RandomWeightedChoose::new(vec![heavy, weightless]);

        // Only the positively-weighted item can ever come back.
        (0..100).for_each(|_| assert_eq!(1, choose.choose_one().unwrap()));

        // Requesting more items than exist yields just the positively-weighted one.
        (0..100).for_each(|_| assert_eq!(vec![1], choose.choose_multiple(3).unwrap()));
    }
}