client/
load_balance.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use enum_dispatch::enum_dispatch;
16use rand::seq::IndexedRandom;
17
18#[enum_dispatch]
19pub trait LoadBalance {
20    fn get_peer<'a>(&self, peers: &'a [String]) -> Option<&'a String>;
21}
22
/// Enum over the available [`LoadBalance`] strategies.
///
/// `#[enum_dispatch(LoadBalance)]` generates the `LoadBalance` impl for this
/// enum (forwarding to the wrapped strategy) and the `From<Random>` conversion
/// that `Loadbalancer::from(Random)` uses in the `Default` impl.
// NOTE(review): the type is spelled `Loadbalancer` while the trait is
// `LoadBalance`; renaming would break existing callers, so it is kept as-is.
#[enum_dispatch(LoadBalance)]
#[derive(Debug)]
pub enum Loadbalancer {
    // `enum_dispatch` shorthand: a variant named after, and wrapping, `Random`.
    Random,
}
28
29impl Default for Loadbalancer {
30    fn default() -> Self {
31        Loadbalancer::from(Random)
32    }
33}
34
/// Stateless strategy that picks a peer at random on every call.
///
/// It holds no data, so it is freely copyable and constructible via
/// [`Default`] as well as the unit literal `Random`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Random;
37
38impl LoadBalance for Random {
39    fn get_peer<'a>(&self, peers: &'a [String]) -> Option<&'a String> {
40        peers.choose(&mut rand::rng())
41    }
42}
43
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::{LoadBalance, Loadbalancer, Random};

    /// Four distinct peer addresses shared by the tests below.
    fn sample_peers() -> Vec<String> {
        vec![
            "127.0.0.1:3001".to_string(),
            "127.0.0.1:3002".to_string(),
            "127.0.0.1:3003".to_string(),
            "127.0.0.1:3004".to_string(),
        ]
    }

    #[test]
    fn test_random_lb() {
        let peers = sample_peers();
        // Borrow instead of cloning the whole vector just to build the lookup set.
        let all: HashSet<&str> = peers.iter().map(String::as_str).collect();

        let random = Random;
        let mut seen: HashSet<&str> = HashSet::new();
        for _ in 0..1000 {
            let peer = random
                .get_peer(&peers)
                .expect("a non-empty peer list must yield a peer");
            assert!(all.contains(peer.as_str()));
            seen.insert(peer.as_str());
        }
        // Missing any of 4 peers over 1000 uniform draws has probability about
        // 4 * 0.75^1000, so this only fails if selection is not actually spread.
        assert_eq!(seen, all);
    }

    #[test]
    fn test_random_lb_empty_peers() {
        let peers: Vec<String> = Vec::new();
        assert!(Random.get_peer(&peers).is_none());
    }

    #[test]
    fn test_random_lb_single_peer() {
        let peers = vec!["127.0.0.1:3001".to_string()];
        for _ in 0..10 {
            assert_eq!(Random.get_peer(&peers), Some(&peers[0]));
        }
    }

    #[test]
    fn test_default_loadbalancer_dispatches() {
        let peers = sample_peers();
        let lb = Loadbalancer::default();
        let peer = lb
            .get_peer(&peers)
            .expect("default load balancer must pick from a non-empty list");
        assert!(peers.contains(peer));

        let empty: Vec<String> = Vec::new();
        assert!(lb.get_peer(&empty).is_none());
    }
}
66}