client/
client.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use api::v1::flow::flow_client::FlowClient as PbFlowClient;
18use api::v1::health_check_client::HealthCheckClient;
19use api::v1::prometheus_gateway_client::PrometheusGatewayClient;
20use api::v1::region::region_client::RegionClient as PbRegionClient;
21use api::v1::HealthCheckRequest;
22use arrow_flight::flight_service_client::FlightServiceClient;
23use common_grpc::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption};
24use parking_lot::RwLock;
25use snafu::{OptionExt, ResultExt};
26use tonic::codec::CompressionEncoding;
27use tonic::transport::Channel;
28
29use crate::load_balance::{LoadBalance, Loadbalancer};
30use crate::{error, Result};
31
32pub struct FlightClient {
33    addr: String,
34    client: FlightServiceClient<Channel>,
35}
36
37impl FlightClient {
38    pub fn addr(&self) -> &str {
39        &self.addr
40    }
41
42    pub fn mut_inner(&mut self) -> &mut FlightServiceClient<Channel> {
43        &mut self.client
44    }
45}
46
47#[derive(Clone, Debug, Default)]
48pub struct Client {
49    inner: Arc<Inner>,
50}
51
52#[derive(Debug, Default)]
53struct Inner {
54    channel_manager: ChannelManager,
55    peers: Arc<RwLock<Vec<String>>>,
56    load_balance: Loadbalancer,
57}
58
59impl Inner {
60    fn with_manager(channel_manager: ChannelManager) -> Self {
61        Self {
62            channel_manager,
63            ..Default::default()
64        }
65    }
66
67    fn set_peers(&self, peers: Vec<String>) {
68        let mut guard = self.peers.write();
69        *guard = peers;
70    }
71
72    fn get_peer(&self) -> Option<String> {
73        let guard = self.peers.read();
74        self.load_balance.get_peer(&guard).cloned()
75    }
76}
77
78impl Client {
79    pub fn new() -> Self {
80        Default::default()
81    }
82
83    pub fn with_urls<U, A>(urls: A) -> Self
84    where
85        U: AsRef<str>,
86        A: AsRef<[U]>,
87    {
88        Self::with_manager_and_urls(ChannelManager::new(), urls)
89    }
90
91    pub fn with_tls_and_urls<U, A>(urls: A, client_tls: ClientTlsOption) -> Result<Self>
92    where
93        U: AsRef<str>,
94        A: AsRef<[U]>,
95    {
96        let channel_config = ChannelConfig::default().client_tls_config(client_tls);
97        let channel_manager = ChannelManager::with_tls_config(channel_config)
98            .context(error::CreateTlsChannelSnafu)?;
99        Ok(Self::with_manager_and_urls(channel_manager, urls))
100    }
101
102    pub fn with_manager_and_urls<U, A>(channel_manager: ChannelManager, urls: A) -> Self
103    where
104        U: AsRef<str>,
105        A: AsRef<[U]>,
106    {
107        let inner = Inner::with_manager(channel_manager);
108        let urls: Vec<String> = urls
109            .as_ref()
110            .iter()
111            .map(|peer| peer.as_ref().to_string())
112            .collect();
113        inner.set_peers(urls);
114        Self {
115            inner: Arc::new(inner),
116        }
117    }
118
119    pub fn start<U, A>(&self, urls: A)
120    where
121        U: AsRef<str>,
122        A: AsRef<[U]>,
123    {
124        let urls: Vec<String> = urls
125            .as_ref()
126            .iter()
127            .map(|peer| peer.as_ref().to_string())
128            .collect();
129
130        self.inner.set_peers(urls);
131    }
132
133    pub fn find_channel(&self) -> Result<(String, Channel)> {
134        let addr = self
135            .inner
136            .get_peer()
137            .context(error::IllegalGrpcClientStateSnafu {
138                err_msg: "No available peer found",
139            })?;
140
141        let channel = self
142            .inner
143            .channel_manager
144            .get(&addr)
145            .context(error::CreateChannelSnafu { addr: &addr })?;
146        Ok((addr, channel))
147    }
148
149    pub fn max_grpc_recv_message_size(&self) -> usize {
150        self.inner
151            .channel_manager
152            .config()
153            .max_recv_message_size
154            .as_bytes() as usize
155    }
156
157    pub fn max_grpc_send_message_size(&self) -> usize {
158        self.inner
159            .channel_manager
160            .config()
161            .max_send_message_size
162            .as_bytes() as usize
163    }
164
165    pub fn make_flight_client(&self) -> Result<FlightClient> {
166        let (addr, channel) = self.find_channel()?;
167
168        let client = FlightServiceClient::new(channel)
169            .max_decoding_message_size(self.max_grpc_recv_message_size())
170            .max_encoding_message_size(self.max_grpc_send_message_size())
171            .accept_compressed(CompressionEncoding::Zstd)
172            .send_compressed(CompressionEncoding::Zstd);
173
174        Ok(FlightClient { addr, client })
175    }
176
177    pub(crate) fn raw_region_client(&self) -> Result<(String, PbRegionClient<Channel>)> {
178        let (addr, channel) = self.find_channel()?;
179        let client = PbRegionClient::new(channel)
180            .max_decoding_message_size(self.max_grpc_recv_message_size())
181            .max_encoding_message_size(self.max_grpc_send_message_size())
182            .accept_compressed(CompressionEncoding::Zstd)
183            .send_compressed(CompressionEncoding::Zstd);
184        Ok((addr, client))
185    }
186
187    pub(crate) fn raw_flow_client(&self) -> Result<(String, PbFlowClient<Channel>)> {
188        let (addr, channel) = self.find_channel()?;
189        let client = PbFlowClient::new(channel)
190            .max_decoding_message_size(self.max_grpc_recv_message_size())
191            .max_encoding_message_size(self.max_grpc_send_message_size())
192            .accept_compressed(CompressionEncoding::Zstd)
193            .send_compressed(CompressionEncoding::Zstd);
194        Ok((addr, client))
195    }
196
197    pub fn make_prometheus_gateway_client(&self) -> Result<PrometheusGatewayClient<Channel>> {
198        let (_, channel) = self.find_channel()?;
199        let client = PrometheusGatewayClient::new(channel)
200            .accept_compressed(CompressionEncoding::Gzip)
201            .accept_compressed(CompressionEncoding::Zstd)
202            .send_compressed(CompressionEncoding::Gzip)
203            .send_compressed(CompressionEncoding::Zstd);
204        Ok(client)
205    }
206
207    pub async fn health_check(&self) -> Result<()> {
208        let (_, channel) = self.find_channel()?;
209        let mut client = HealthCheckClient::new(channel);
210        let _ = client.health_check(HealthCheckRequest {}).await?;
211        Ok(())
212    }
213}
214
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::Inner;
    use crate::load_balance::Loadbalancer;

    /// Three loopback peer addresses on consecutive ports 3001..=3003.
    fn mock_peers() -> Vec<String> {
        (3001..=3003)
            .map(|port| format!("127.0.0.1:{port}"))
            .collect()
    }

    #[test]
    fn test_inner() {
        let inner = Inner::default();

        // A default `Inner` balances randomly and has no peers to choose from yet.
        assert!(matches!(
            inner.load_balance,
            Loadbalancer::Random(crate::load_balance::Random)
        ));
        assert!(inner.get_peer().is_none());

        let peers = mock_peers();
        let expected: HashSet<String> = peers.iter().cloned().collect();
        inner.set_peers(peers);

        // Every pick must be one of the configured peers.
        (0..20).for_each(|_| {
            let picked = inner.get_peer().unwrap();
            assert!(expected.contains(&picked));
        });
    }
}