1use std::sync::Arc;
16
17use api::v1::flow::flow_client::FlowClient as PbFlowClient;
18use api::v1::health_check_client::HealthCheckClient;
19use api::v1::prometheus_gateway_client::PrometheusGatewayClient;
20use api::v1::region::region_client::RegionClient as PbRegionClient;
21use api::v1::HealthCheckRequest;
22use arrow_flight::flight_service_client::FlightServiceClient;
23use common_grpc::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption};
24use parking_lot::RwLock;
25use snafu::{OptionExt, ResultExt};
26use tonic::codec::CompressionEncoding;
27use tonic::transport::Channel;
28
29use crate::load_balance::{LoadBalance, Loadbalancer};
30use crate::{error, Result};
31
32pub struct FlightClient {
33 addr: String,
34 client: FlightServiceClient<Channel>,
35}
36
37impl FlightClient {
38 pub fn addr(&self) -> &str {
39 &self.addr
40 }
41
42 pub fn mut_inner(&mut self) -> &mut FlightServiceClient<Channel> {
43 &mut self.client
44 }
45}
46
47#[derive(Clone, Debug, Default)]
48pub struct Client {
49 inner: Arc<Inner>,
50}
51
52#[derive(Debug, Default)]
53struct Inner {
54 channel_manager: ChannelManager,
55 peers: Arc<RwLock<Vec<String>>>,
56 load_balance: Loadbalancer,
57}
58
59impl Inner {
60 fn with_manager(channel_manager: ChannelManager) -> Self {
61 Self {
62 channel_manager,
63 ..Default::default()
64 }
65 }
66
67 fn set_peers(&self, peers: Vec<String>) {
68 let mut guard = self.peers.write();
69 *guard = peers;
70 }
71
72 fn get_peer(&self) -> Option<String> {
73 let guard = self.peers.read();
74 self.load_balance.get_peer(&guard).cloned()
75 }
76}
77
78impl Client {
79 pub fn new() -> Self {
80 Default::default()
81 }
82
83 pub fn with_urls<U, A>(urls: A) -> Self
84 where
85 U: AsRef<str>,
86 A: AsRef<[U]>,
87 {
88 Self::with_manager_and_urls(ChannelManager::new(), urls)
89 }
90
91 pub fn with_tls_and_urls<U, A>(urls: A, client_tls: ClientTlsOption) -> Result<Self>
92 where
93 U: AsRef<str>,
94 A: AsRef<[U]>,
95 {
96 let channel_config = ChannelConfig::default().client_tls_config(client_tls);
97 let channel_manager = ChannelManager::with_tls_config(channel_config)
98 .context(error::CreateTlsChannelSnafu)?;
99 Ok(Self::with_manager_and_urls(channel_manager, urls))
100 }
101
102 pub fn with_manager_and_urls<U, A>(channel_manager: ChannelManager, urls: A) -> Self
103 where
104 U: AsRef<str>,
105 A: AsRef<[U]>,
106 {
107 let inner = Inner::with_manager(channel_manager);
108 let urls: Vec<String> = urls
109 .as_ref()
110 .iter()
111 .map(|peer| peer.as_ref().to_string())
112 .collect();
113 inner.set_peers(urls);
114 Self {
115 inner: Arc::new(inner),
116 }
117 }
118
119 pub fn start<U, A>(&self, urls: A)
120 where
121 U: AsRef<str>,
122 A: AsRef<[U]>,
123 {
124 let urls: Vec<String> = urls
125 .as_ref()
126 .iter()
127 .map(|peer| peer.as_ref().to_string())
128 .collect();
129
130 self.inner.set_peers(urls);
131 }
132
133 pub fn find_channel(&self) -> Result<(String, Channel)> {
134 let addr = self
135 .inner
136 .get_peer()
137 .context(error::IllegalGrpcClientStateSnafu {
138 err_msg: "No available peer found",
139 })?;
140
141 let channel = self
142 .inner
143 .channel_manager
144 .get(&addr)
145 .context(error::CreateChannelSnafu { addr: &addr })?;
146 Ok((addr, channel))
147 }
148
149 pub fn max_grpc_recv_message_size(&self) -> usize {
150 self.inner
151 .channel_manager
152 .config()
153 .max_recv_message_size
154 .as_bytes() as usize
155 }
156
157 pub fn max_grpc_send_message_size(&self) -> usize {
158 self.inner
159 .channel_manager
160 .config()
161 .max_send_message_size
162 .as_bytes() as usize
163 }
164
165 pub fn make_flight_client(&self) -> Result<FlightClient> {
166 let (addr, channel) = self.find_channel()?;
167
168 let client = FlightServiceClient::new(channel)
169 .max_decoding_message_size(self.max_grpc_recv_message_size())
170 .max_encoding_message_size(self.max_grpc_send_message_size())
171 .accept_compressed(CompressionEncoding::Zstd)
172 .send_compressed(CompressionEncoding::Zstd);
173
174 Ok(FlightClient { addr, client })
175 }
176
177 pub(crate) fn raw_region_client(&self) -> Result<(String, PbRegionClient<Channel>)> {
178 let (addr, channel) = self.find_channel()?;
179 let client = PbRegionClient::new(channel)
180 .max_decoding_message_size(self.max_grpc_recv_message_size())
181 .max_encoding_message_size(self.max_grpc_send_message_size())
182 .accept_compressed(CompressionEncoding::Zstd)
183 .send_compressed(CompressionEncoding::Zstd);
184 Ok((addr, client))
185 }
186
187 pub(crate) fn raw_flow_client(&self) -> Result<(String, PbFlowClient<Channel>)> {
188 let (addr, channel) = self.find_channel()?;
189 let client = PbFlowClient::new(channel)
190 .max_decoding_message_size(self.max_grpc_recv_message_size())
191 .max_encoding_message_size(self.max_grpc_send_message_size())
192 .accept_compressed(CompressionEncoding::Zstd)
193 .send_compressed(CompressionEncoding::Zstd);
194 Ok((addr, client))
195 }
196
197 pub fn make_prometheus_gateway_client(&self) -> Result<PrometheusGatewayClient<Channel>> {
198 let (_, channel) = self.find_channel()?;
199 let client = PrometheusGatewayClient::new(channel)
200 .accept_compressed(CompressionEncoding::Gzip)
201 .accept_compressed(CompressionEncoding::Zstd)
202 .send_compressed(CompressionEncoding::Gzip)
203 .send_compressed(CompressionEncoding::Zstd);
204 Ok(client)
205 }
206
207 pub async fn health_check(&self) -> Result<()> {
208 let (_, channel) = self.find_channel()?;
209 let mut client = HealthCheckClient::new(channel);
210 let _ = client.health_check(HealthCheckRequest {}).await?;
211 Ok(())
212 }
213}
214
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::Inner;
    use crate::load_balance::Loadbalancer;

    /// Three loopback addresses on ports 3001..=3003 used as a fake peer list.
    fn mock_peers() -> Vec<String> {
        (3001..=3003)
            .map(|port| format!("127.0.0.1:{port}"))
            .collect()
    }

    #[test]
    fn test_inner() {
        let inner = Inner::default();

        // A default `Inner` uses the random balancer and starts with no peers.
        assert!(matches!(
            inner.load_balance,
            Loadbalancer::Random(crate::load_balance::Random)
        ));
        assert!(inner.get_peer().is_none());

        let peers = mock_peers();
        let expected: HashSet<String> = peers.iter().cloned().collect();
        inner.set_peers(peers);

        // Every pick must be one of the configured peers.
        for _ in 0..20 {
            let picked = inner.get_peer().unwrap();
            assert!(expected.contains(&picked));
        }
    }
}