1use std::sync::Arc;
16
17use api::v1::HealthCheckRequest;
18use api::v1::flow::flow_client::FlowClient as PbFlowClient;
19use api::v1::health_check_client::HealthCheckClient;
20use api::v1::prometheus_gateway_client::PrometheusGatewayClient;
21use api::v1::region::region_client::RegionClient as PbRegionClient;
22use arrow_flight::flight_service_client::FlightServiceClient;
23use common_grpc::channel_manager::{
24 ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config,
25};
26use parking_lot::RwLock;
27use snafu::{OptionExt, ResultExt};
28use tonic::codec::CompressionEncoding;
29use tonic::transport::Channel;
30
31use crate::load_balance::{LoadBalance, Loadbalancer};
32use crate::{Result, error};
33
34pub struct FlightClient {
35 addr: String,
36 client: FlightServiceClient<Channel>,
37}
38
39impl FlightClient {
40 pub fn addr(&self) -> &str {
41 &self.addr
42 }
43
44 pub fn mut_inner(&mut self) -> &mut FlightServiceClient<Channel> {
45 &mut self.client
46 }
47}
48
49#[derive(Clone, Debug, Default)]
50pub struct Client {
51 inner: Arc<Inner>,
52}
53
54#[derive(Debug, Default)]
55struct Inner {
56 channel_manager: ChannelManager,
57 peers: Arc<RwLock<Vec<String>>>,
58 load_balance: Loadbalancer,
59}
60
61impl Inner {
62 fn with_manager(channel_manager: ChannelManager) -> Self {
63 Self {
64 channel_manager,
65 ..Default::default()
66 }
67 }
68
69 fn set_peers(&self, peers: Vec<String>) {
70 let mut guard = self.peers.write();
71 *guard = peers;
72 }
73
74 fn get_peer(&self) -> Option<String> {
75 let guard = self.peers.read();
76 self.load_balance.get_peer(&guard).cloned()
77 }
78}
79
80impl Client {
81 pub fn new() -> Self {
82 Default::default()
83 }
84
85 pub fn with_urls<U, A>(urls: A) -> Self
86 where
87 U: AsRef<str>,
88 A: AsRef<[U]>,
89 {
90 Self::with_manager_and_urls(ChannelManager::new(), urls)
91 }
92
93 pub fn with_tls_and_urls<U, A>(urls: A, client_tls: ClientTlsOption) -> Result<Self>
94 where
95 U: AsRef<str>,
96 A: AsRef<[U]>,
97 {
98 let channel_config = ChannelConfig::default().client_tls_config(client_tls);
99 let tls_config = load_tls_config(channel_config.client_tls.as_ref())
100 .context(error::CreateTlsChannelSnafu)?;
101 let channel_manager = ChannelManager::with_config(channel_config, tls_config);
102 Ok(Self::with_manager_and_urls(channel_manager, urls))
103 }
104
105 pub fn with_manager_and_urls<U, A>(channel_manager: ChannelManager, urls: A) -> Self
106 where
107 U: AsRef<str>,
108 A: AsRef<[U]>,
109 {
110 let inner = Inner::with_manager(channel_manager);
111 let urls: Vec<String> = urls
112 .as_ref()
113 .iter()
114 .map(|peer| peer.as_ref().to_string())
115 .collect();
116 inner.set_peers(urls);
117 Self {
118 inner: Arc::new(inner),
119 }
120 }
121
122 pub fn start<U, A>(&self, urls: A)
123 where
124 U: AsRef<str>,
125 A: AsRef<[U]>,
126 {
127 let urls: Vec<String> = urls
128 .as_ref()
129 .iter()
130 .map(|peer| peer.as_ref().to_string())
131 .collect();
132
133 self.inner.set_peers(urls);
134 }
135
136 pub fn find_channel(&self) -> Result<(String, Channel)> {
137 let addr = self
138 .inner
139 .get_peer()
140 .context(error::IllegalGrpcClientStateSnafu {
141 err_msg: "No available peer found",
142 })?;
143
144 let channel = self
145 .inner
146 .channel_manager
147 .get(&addr)
148 .context(error::CreateChannelSnafu { addr: &addr })?;
149 Ok((addr, channel))
150 }
151
152 pub fn max_grpc_recv_message_size(&self) -> usize {
153 self.inner
154 .channel_manager
155 .config()
156 .max_recv_message_size
157 .as_bytes() as usize
158 }
159
160 pub fn max_grpc_send_message_size(&self) -> usize {
161 self.inner
162 .channel_manager
163 .config()
164 .max_send_message_size
165 .as_bytes() as usize
166 }
167
168 pub fn make_flight_client(
169 &self,
170 send_compression: bool,
171 accept_compression: bool,
172 ) -> Result<FlightClient> {
173 let (addr, channel) = self.find_channel()?;
174
175 let mut client = FlightServiceClient::new(channel)
176 .max_decoding_message_size(self.max_grpc_recv_message_size())
177 .max_encoding_message_size(self.max_grpc_send_message_size());
178 if send_compression {
180 client = client.send_compressed(CompressionEncoding::Zstd);
181 }
182 if accept_compression {
183 client = client.accept_compressed(CompressionEncoding::Zstd);
184 }
185
186 Ok(FlightClient { addr, client })
187 }
188
189 pub(crate) fn raw_region_client(&self) -> Result<(String, PbRegionClient<Channel>)> {
190 let (addr, channel) = self.find_channel()?;
191 let client = PbRegionClient::new(channel)
192 .max_decoding_message_size(self.max_grpc_recv_message_size())
193 .max_encoding_message_size(self.max_grpc_send_message_size());
194 Ok((addr, client))
195 }
196
197 pub(crate) fn raw_flow_client(&self) -> Result<(String, PbFlowClient<Channel>)> {
198 let (addr, channel) = self.find_channel()?;
199 let client = PbFlowClient::new(channel)
200 .max_decoding_message_size(self.max_grpc_recv_message_size())
201 .max_encoding_message_size(self.max_grpc_send_message_size())
202 .accept_compressed(CompressionEncoding::Zstd)
203 .send_compressed(CompressionEncoding::Zstd);
204 Ok((addr, client))
205 }
206
207 pub fn make_prometheus_gateway_client(&self) -> Result<PrometheusGatewayClient<Channel>> {
208 let (_, channel) = self.find_channel()?;
209 let client = PrometheusGatewayClient::new(channel)
210 .accept_compressed(CompressionEncoding::Gzip)
211 .accept_compressed(CompressionEncoding::Zstd)
212 .send_compressed(CompressionEncoding::Gzip)
213 .send_compressed(CompressionEncoding::Zstd);
214 Ok(client)
215 }
216
217 pub async fn health_check(&self) -> Result<()> {
218 let (_, channel) = self.find_channel()?;
219 let mut client = HealthCheckClient::new(channel);
220 let _ = client.health_check(HealthCheckRequest {}).await?;
221 Ok(())
222 }
223}
224
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::Inner;
    use crate::load_balance::Loadbalancer;

    /// Three distinct loopback addresses used as the peer fixture.
    fn mock_peers() -> Vec<String> {
        ["127.0.0.1:3001", "127.0.0.1:3002", "127.0.0.1:3003"]
            .iter()
            .map(|addr| addr.to_string())
            .collect()
    }

    #[test]
    fn test_inner() {
        let inner = Inner::default();

        // A default `Inner` uses the random load balancer and has no peer to offer.
        let is_random = matches!(
            inner.load_balance,
            Loadbalancer::Random(crate::load_balance::Random)
        );
        assert!(is_random);
        assert!(inner.get_peer().is_none());

        let peers = mock_peers();
        inner.set_peers(peers.clone());
        let known: HashSet<String> = peers.into_iter().collect();

        // Every selection must be one of the configured peers.
        for _ in 0..20 {
            let picked = inner.get_peer().unwrap();
            assert!(known.contains(&picked));
        }
    }
}