client/
client.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use api::v1::HealthCheckRequest;
18use api::v1::flow::flow_client::FlowClient as PbFlowClient;
19use api::v1::health_check_client::HealthCheckClient;
20use api::v1::prometheus_gateway_client::PrometheusGatewayClient;
21use api::v1::region::region_client::RegionClient as PbRegionClient;
22use arrow_flight::flight_service_client::FlightServiceClient;
23use common_grpc::channel_manager::{
24    ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config,
25};
26use parking_lot::RwLock;
27use snafu::{OptionExt, ResultExt};
28use tonic::codec::CompressionEncoding;
29use tonic::transport::Channel;
30
31use crate::load_balance::{LoadBalance, Loadbalancer};
32use crate::{Result, error};
33
34pub struct FlightClient {
35    addr: String,
36    client: FlightServiceClient<Channel>,
37}
38
39impl FlightClient {
40    pub fn addr(&self) -> &str {
41        &self.addr
42    }
43
44    pub fn mut_inner(&mut self) -> &mut FlightServiceClient<Channel> {
45        &mut self.client
46    }
47}
48
49#[derive(Clone, Debug, Default)]
50pub struct Client {
51    inner: Arc<Inner>,
52}
53
54#[derive(Debug, Default)]
55struct Inner {
56    channel_manager: ChannelManager,
57    peers: Arc<RwLock<Vec<String>>>,
58    load_balance: Loadbalancer,
59}
60
61impl Inner {
62    fn with_manager(channel_manager: ChannelManager) -> Self {
63        Self {
64            channel_manager,
65            ..Default::default()
66        }
67    }
68
69    fn set_peers(&self, peers: Vec<String>) {
70        let mut guard = self.peers.write();
71        *guard = peers;
72    }
73
74    fn get_peer(&self) -> Option<String> {
75        let guard = self.peers.read();
76        self.load_balance.get_peer(&guard).cloned()
77    }
78}
79
80impl Client {
81    pub fn new() -> Self {
82        Default::default()
83    }
84
85    pub fn with_urls<U, A>(urls: A) -> Self
86    where
87        U: AsRef<str>,
88        A: AsRef<[U]>,
89    {
90        Self::with_manager_and_urls(ChannelManager::new(), urls)
91    }
92
93    pub fn with_tls_and_urls<U, A>(urls: A, client_tls: ClientTlsOption) -> Result<Self>
94    where
95        U: AsRef<str>,
96        A: AsRef<[U]>,
97    {
98        let channel_config = ChannelConfig::default().client_tls_config(client_tls);
99        let tls_config = load_tls_config(channel_config.client_tls.as_ref())
100            .context(error::CreateTlsChannelSnafu)?;
101        let channel_manager = ChannelManager::with_config(channel_config, tls_config);
102        Ok(Self::with_manager_and_urls(channel_manager, urls))
103    }
104
105    pub fn with_manager_and_urls<U, A>(channel_manager: ChannelManager, urls: A) -> Self
106    where
107        U: AsRef<str>,
108        A: AsRef<[U]>,
109    {
110        let inner = Inner::with_manager(channel_manager);
111        let urls: Vec<String> = urls
112            .as_ref()
113            .iter()
114            .map(|peer| peer.as_ref().to_string())
115            .collect();
116        inner.set_peers(urls);
117        Self {
118            inner: Arc::new(inner),
119        }
120    }
121
122    pub fn start<U, A>(&self, urls: A)
123    where
124        U: AsRef<str>,
125        A: AsRef<[U]>,
126    {
127        let urls: Vec<String> = urls
128            .as_ref()
129            .iter()
130            .map(|peer| peer.as_ref().to_string())
131            .collect();
132
133        self.inner.set_peers(urls);
134    }
135
136    pub fn find_channel(&self) -> Result<(String, Channel)> {
137        let addr = self
138            .inner
139            .get_peer()
140            .context(error::IllegalGrpcClientStateSnafu {
141                err_msg: "No available peer found",
142            })?;
143
144        let channel = self
145            .inner
146            .channel_manager
147            .get(&addr)
148            .context(error::CreateChannelSnafu { addr: &addr })?;
149        Ok((addr, channel))
150    }
151
152    pub fn max_grpc_recv_message_size(&self) -> usize {
153        self.inner
154            .channel_manager
155            .config()
156            .max_recv_message_size
157            .as_bytes() as usize
158    }
159
160    pub fn max_grpc_send_message_size(&self) -> usize {
161        self.inner
162            .channel_manager
163            .config()
164            .max_send_message_size
165            .as_bytes() as usize
166    }
167
168    pub fn make_flight_client(
169        &self,
170        send_compression: bool,
171        accept_compression: bool,
172    ) -> Result<FlightClient> {
173        let (addr, channel) = self.find_channel()?;
174
175        let mut client = FlightServiceClient::new(channel)
176            .max_decoding_message_size(self.max_grpc_recv_message_size())
177            .max_encoding_message_size(self.max_grpc_send_message_size());
178        // todo(hl): support compression methods.
179        if send_compression {
180            client = client.send_compressed(CompressionEncoding::Zstd);
181        }
182        if accept_compression {
183            client = client.accept_compressed(CompressionEncoding::Zstd);
184        }
185
186        Ok(FlightClient { addr, client })
187    }
188
189    pub(crate) fn raw_region_client(&self) -> Result<(String, PbRegionClient<Channel>)> {
190        let (addr, channel) = self.find_channel()?;
191        let client = PbRegionClient::new(channel)
192            .max_decoding_message_size(self.max_grpc_recv_message_size())
193            .max_encoding_message_size(self.max_grpc_send_message_size());
194        Ok((addr, client))
195    }
196
197    pub(crate) fn raw_flow_client(&self) -> Result<(String, PbFlowClient<Channel>)> {
198        let (addr, channel) = self.find_channel()?;
199        let client = PbFlowClient::new(channel)
200            .max_decoding_message_size(self.max_grpc_recv_message_size())
201            .max_encoding_message_size(self.max_grpc_send_message_size())
202            .accept_compressed(CompressionEncoding::Zstd)
203            .send_compressed(CompressionEncoding::Zstd);
204        Ok((addr, client))
205    }
206
207    pub fn make_prometheus_gateway_client(&self) -> Result<PrometheusGatewayClient<Channel>> {
208        let (_, channel) = self.find_channel()?;
209        let client = PrometheusGatewayClient::new(channel)
210            .accept_compressed(CompressionEncoding::Gzip)
211            .accept_compressed(CompressionEncoding::Zstd)
212            .send_compressed(CompressionEncoding::Gzip)
213            .send_compressed(CompressionEncoding::Zstd);
214        Ok(client)
215    }
216
217    pub async fn health_check(&self) -> Result<()> {
218        let (_, channel) = self.find_channel()?;
219        let mut client = HealthCheckClient::new(channel);
220        let _ = client.health_check(HealthCheckRequest {}).await?;
221        Ok(())
222    }
223}
224
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::Inner;
    use crate::load_balance::Loadbalancer;

    /// Three local peer addresses on consecutive ports 3001..=3003.
    fn mock_peers() -> Vec<String> {
        (3001..=3003)
            .map(|port| format!("127.0.0.1:{port}"))
            .collect()
    }

    #[test]
    fn test_inner() {
        let inner = Inner::default();

        // A fresh `Inner` uses random load balancing and has no peer to hand out.
        assert!(matches!(
            inner.load_balance,
            Loadbalancer::Random(crate::load_balance::Random)
        ));
        assert!(inner.get_peer().is_none());

        let peers = mock_peers();
        let expected: HashSet<String> = peers.iter().cloned().collect();
        inner.set_peers(peers);

        // Every pick must be one of the configured peers.
        for _ in 0..20 {
            let picked = inner.get_peer().expect("peers were just set");
            assert!(expected.contains(&picked));
        }
    }
}