1use std::sync::Arc;
16
17use api::v1::flow::flow_client::FlowClient as PbFlowClient;
18use api::v1::health_check_client::HealthCheckClient;
19use api::v1::prometheus_gateway_client::PrometheusGatewayClient;
20use api::v1::region::region_client::RegionClient as PbRegionClient;
21use api::v1::HealthCheckRequest;
22use arrow_flight::flight_service_client::FlightServiceClient;
23use common_grpc::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption};
24use parking_lot::RwLock;
25use snafu::{OptionExt, ResultExt};
26use tonic::codec::CompressionEncoding;
27use tonic::transport::Channel;
28
29use crate::load_balance::{LoadBalance, Loadbalancer};
30use crate::{error, Result};
31
32pub struct FlightClient {
33 addr: String,
34 client: FlightServiceClient<Channel>,
35}
36
37impl FlightClient {
38 pub fn addr(&self) -> &str {
39 &self.addr
40 }
41
42 pub fn mut_inner(&mut self) -> &mut FlightServiceClient<Channel> {
43 &mut self.client
44 }
45}
46
47#[derive(Clone, Debug, Default)]
48pub struct Client {
49 inner: Arc<Inner>,
50}
51
52#[derive(Debug, Default)]
53struct Inner {
54 channel_manager: ChannelManager,
55 peers: Arc<RwLock<Vec<String>>>,
56 load_balance: Loadbalancer,
57}
58
59impl Inner {
60 fn with_manager(channel_manager: ChannelManager) -> Self {
61 Self {
62 channel_manager,
63 ..Default::default()
64 }
65 }
66
67 fn set_peers(&self, peers: Vec<String>) {
68 let mut guard = self.peers.write();
69 *guard = peers;
70 }
71
72 fn get_peer(&self) -> Option<String> {
73 let guard = self.peers.read();
74 self.load_balance.get_peer(&guard).cloned()
75 }
76}
77
78impl Client {
79 pub fn new() -> Self {
80 Default::default()
81 }
82
83 pub fn with_urls<U, A>(urls: A) -> Self
84 where
85 U: AsRef<str>,
86 A: AsRef<[U]>,
87 {
88 Self::with_manager_and_urls(ChannelManager::new(), urls)
89 }
90
91 pub fn with_tls_and_urls<U, A>(urls: A, client_tls: ClientTlsOption) -> Result<Self>
92 where
93 U: AsRef<str>,
94 A: AsRef<[U]>,
95 {
96 let channel_config = ChannelConfig::default().client_tls_config(client_tls);
97 let channel_manager = ChannelManager::with_tls_config(channel_config)
98 .context(error::CreateTlsChannelSnafu)?;
99 Ok(Self::with_manager_and_urls(channel_manager, urls))
100 }
101
102 pub fn with_manager_and_urls<U, A>(channel_manager: ChannelManager, urls: A) -> Self
103 where
104 U: AsRef<str>,
105 A: AsRef<[U]>,
106 {
107 let inner = Inner::with_manager(channel_manager);
108 let urls: Vec<String> = urls
109 .as_ref()
110 .iter()
111 .map(|peer| peer.as_ref().to_string())
112 .collect();
113 inner.set_peers(urls);
114 Self {
115 inner: Arc::new(inner),
116 }
117 }
118
119 pub fn start<U, A>(&self, urls: A)
120 where
121 U: AsRef<str>,
122 A: AsRef<[U]>,
123 {
124 let urls: Vec<String> = urls
125 .as_ref()
126 .iter()
127 .map(|peer| peer.as_ref().to_string())
128 .collect();
129
130 self.inner.set_peers(urls);
131 }
132
133 pub fn find_channel(&self) -> Result<(String, Channel)> {
134 let addr = self
135 .inner
136 .get_peer()
137 .context(error::IllegalGrpcClientStateSnafu {
138 err_msg: "No available peer found",
139 })?;
140
141 let channel = self
142 .inner
143 .channel_manager
144 .get(&addr)
145 .context(error::CreateChannelSnafu { addr: &addr })?;
146 Ok((addr, channel))
147 }
148
149 pub fn max_grpc_recv_message_size(&self) -> usize {
150 self.inner
151 .channel_manager
152 .config()
153 .max_recv_message_size
154 .as_bytes() as usize
155 }
156
157 pub fn max_grpc_send_message_size(&self) -> usize {
158 self.inner
159 .channel_manager
160 .config()
161 .max_send_message_size
162 .as_bytes() as usize
163 }
164
165 pub fn make_flight_client(
166 &self,
167 send_compression: bool,
168 accept_compression: bool,
169 ) -> Result<FlightClient> {
170 let (addr, channel) = self.find_channel()?;
171
172 let mut client = FlightServiceClient::new(channel)
173 .max_decoding_message_size(self.max_grpc_recv_message_size())
174 .max_encoding_message_size(self.max_grpc_send_message_size());
175 if send_compression {
177 client = client.send_compressed(CompressionEncoding::Zstd);
178 }
179 if accept_compression {
180 client = client.accept_compressed(CompressionEncoding::Zstd);
181 }
182
183 Ok(FlightClient { addr, client })
184 }
185
186 pub(crate) fn raw_region_client(&self) -> Result<(String, PbRegionClient<Channel>)> {
187 let (addr, channel) = self.find_channel()?;
188 let client = PbRegionClient::new(channel)
189 .max_decoding_message_size(self.max_grpc_recv_message_size())
190 .max_encoding_message_size(self.max_grpc_send_message_size());
191 Ok((addr, client))
192 }
193
194 pub(crate) fn raw_flow_client(&self) -> Result<(String, PbFlowClient<Channel>)> {
195 let (addr, channel) = self.find_channel()?;
196 let client = PbFlowClient::new(channel)
197 .max_decoding_message_size(self.max_grpc_recv_message_size())
198 .max_encoding_message_size(self.max_grpc_send_message_size())
199 .accept_compressed(CompressionEncoding::Zstd)
200 .send_compressed(CompressionEncoding::Zstd);
201 Ok((addr, client))
202 }
203
204 pub fn make_prometheus_gateway_client(&self) -> Result<PrometheusGatewayClient<Channel>> {
205 let (_, channel) = self.find_channel()?;
206 let client = PrometheusGatewayClient::new(channel)
207 .accept_compressed(CompressionEncoding::Gzip)
208 .accept_compressed(CompressionEncoding::Zstd)
209 .send_compressed(CompressionEncoding::Gzip)
210 .send_compressed(CompressionEncoding::Zstd);
211 Ok(client)
212 }
213
214 pub async fn health_check(&self) -> Result<()> {
215 let (_, channel) = self.find_channel()?;
216 let mut client = HealthCheckClient::new(channel);
217 let _ = client.health_check(HealthCheckRequest {}).await?;
218 Ok(())
219 }
220}
221
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::Inner;
    use crate::load_balance::Loadbalancer;

    /// Three loopback addresses used as a fixed peer set.
    fn mock_peers() -> Vec<String> {
        ["127.0.0.1:3001", "127.0.0.1:3002", "127.0.0.1:3003"]
            .iter()
            .map(|addr| addr.to_string())
            .collect()
    }

    /// A default `Inner` uses the random load balancer, yields no peer while
    /// its list is empty, and afterwards only yields configured peers.
    #[test]
    fn test_inner() {
        let inner = Inner::default();

        assert!(matches!(
            inner.load_balance,
            Loadbalancer::Random(crate::load_balance::Random)
        ));
        assert!(inner.get_peer().is_none());

        let peers = mock_peers();
        inner.set_peers(peers.clone());
        let known: HashSet<String> = peers.into_iter().collect();

        // Selection is load-balanced, so sample repeatedly.
        for _ in 0..20 {
            let picked = inner.get_peer().unwrap();
            assert!(known.contains(&picked));
        }
    }
}