1use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19use common_base::readable_size::ReadableSize;
20use common_telemetry::info;
21use dashmap::mapref::entry::Entry;
22use dashmap::DashMap;
23use lazy_static::lazy_static;
24use snafu::{OptionExt, ResultExt};
25use tokio_util::sync::CancellationToken;
26use tonic::transport::{
27 Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri,
28};
29use tower::Service;
30
31use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, InvalidTlsConfigSnafu, Result};
32
33const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
34pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
35pub const DEFAULT_GRPC_CONNECT_TIMEOUT_SECS: u64 = 1;
36pub const DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
37pub const DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
38
39lazy_static! {
40 static ref ID: AtomicU64 = AtomicU64::new(0);
41}
42
43#[derive(Clone, Debug, Default)]
44pub struct ChannelManager {
45 inner: Arc<Inner>,
46}
47
48#[derive(Debug)]
49struct Inner {
50 id: u64,
51 config: ChannelConfig,
52 client_tls_config: Option<ClientTlsConfig>,
53 pool: Arc<Pool>,
54 channel_recycle_started: AtomicBool,
55 cancel: CancellationToken,
56}
57
58impl Default for Inner {
59 fn default() -> Self {
60 Self::with_config(ChannelConfig::default())
61 }
62}
63
64impl Drop for Inner {
65 fn drop(&mut self) {
66 self.cancel.cancel();
68 }
69}
70
71impl Inner {
72 fn with_config(config: ChannelConfig) -> Self {
73 let id = ID.fetch_add(1, Ordering::Relaxed);
74 let pool = Arc::new(Pool::default());
75 let cancel = CancellationToken::new();
76
77 Self {
78 id,
79 config,
80 client_tls_config: None,
81 pool,
82 channel_recycle_started: AtomicBool::new(false),
83 cancel,
84 }
85 }
86}
87
88impl ChannelManager {
89 pub fn new() -> Self {
90 Default::default()
91 }
92
93 pub fn with_config(config: ChannelConfig) -> Self {
94 let inner = Inner::with_config(config);
95 Self {
96 inner: Arc::new(inner),
97 }
98 }
99
100 pub fn with_tls_config(config: ChannelConfig) -> Result<Self> {
101 let mut inner = Inner::with_config(config.clone());
102
103 let path_config = config.client_tls.context(InvalidTlsConfigSnafu {
105 msg: "no config input",
106 })?;
107
108 let server_root_ca_cert = std::fs::read_to_string(path_config.server_ca_cert_path)
109 .context(InvalidConfigFilePathSnafu)?;
110 let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
111 let client_cert = std::fs::read_to_string(path_config.client_cert_path)
112 .context(InvalidConfigFilePathSnafu)?;
113 let client_key = std::fs::read_to_string(path_config.client_key_path)
114 .context(InvalidConfigFilePathSnafu)?;
115 let client_identity = Identity::from_pem(client_cert, client_key);
116
117 inner.client_tls_config = Some(
118 ClientTlsConfig::new()
119 .ca_certificate(server_root_ca_cert)
120 .identity(client_identity),
121 );
122
123 Ok(Self {
124 inner: Arc::new(inner),
125 })
126 }
127
128 pub fn config(&self) -> &ChannelConfig {
129 &self.inner.config
130 }
131
132 fn pool(&self) -> &Arc<Pool> {
133 &self.inner.pool
134 }
135
136 pub fn get(&self, addr: impl AsRef<str>) -> Result<InnerChannel> {
137 self.trigger_channel_recycling();
138
139 let addr = addr.as_ref();
140 if let Some(inner_ch) = self.pool().get(addr) {
142 return Ok(inner_ch);
143 }
144
145 let entry = match self.pool().entry(addr.to_string()) {
147 Entry::Occupied(entry) => {
148 entry.get().increase_access();
149 entry.into_ref()
150 }
151 Entry::Vacant(entry) => {
152 let endpoint = self.build_endpoint(addr)?;
153 let inner_channel = endpoint.connect_lazy();
154
155 let channel = Channel {
156 channel: inner_channel,
157 access: AtomicUsize::new(1),
158 use_default_connector: true,
159 };
160 entry.insert(channel)
161 }
162 };
163 Ok(entry.channel.clone())
164 }
165
166 pub fn reset_with_connector<C>(
167 &self,
168 addr: impl AsRef<str>,
169 connector: C,
170 ) -> Result<InnerChannel>
171 where
172 C: Service<Uri> + Send + 'static,
173 C::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin,
174 C::Future: Send + 'static,
175 Box<dyn std::error::Error + Send + Sync>: From<C::Error> + Send + 'static,
176 {
177 let addr = addr.as_ref();
178 let endpoint = self.build_endpoint(addr)?;
179 let inner_channel = endpoint.connect_with_connector_lazy(connector);
180 let channel = Channel {
181 channel: inner_channel.clone(),
182 access: AtomicUsize::new(1),
183 use_default_connector: false,
184 };
185 self.pool().put(addr, channel);
186
187 Ok(inner_channel)
188 }
189
190 pub fn retain_channel<F>(&self, f: F)
191 where
192 F: FnMut(&String, &mut Channel) -> bool,
193 {
194 self.pool().retain_channel(f);
195 }
196
197 fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
198 let http_prefix = if self.inner.client_tls_config.is_some() {
199 "https"
200 } else {
201 "http"
202 };
203
204 let mut endpoint =
205 Endpoint::new(format!("{http_prefix}://{addr}")).context(CreateChannelSnafu)?;
206
207 if let Some(dur) = self.config().timeout {
208 endpoint = endpoint.timeout(dur);
209 }
210 if let Some(dur) = self.config().connect_timeout {
211 endpoint = endpoint.connect_timeout(dur);
212 }
213 if let Some(limit) = self.config().concurrency_limit {
214 endpoint = endpoint.concurrency_limit(limit);
215 }
216 if let Some((limit, dur)) = self.config().rate_limit {
217 endpoint = endpoint.rate_limit(limit, dur);
218 }
219 if let Some(size) = self.config().initial_stream_window_size {
220 endpoint = endpoint.initial_stream_window_size(size);
221 }
222 if let Some(size) = self.config().initial_connection_window_size {
223 endpoint = endpoint.initial_connection_window_size(size);
224 }
225 if let Some(dur) = self.config().http2_keep_alive_interval {
226 endpoint = endpoint.http2_keep_alive_interval(dur);
227 }
228 if let Some(dur) = self.config().http2_keep_alive_timeout {
229 endpoint = endpoint.keep_alive_timeout(dur);
230 }
231 if let Some(enabled) = self.config().http2_keep_alive_while_idle {
232 endpoint = endpoint.keep_alive_while_idle(enabled);
233 }
234 if let Some(enabled) = self.config().http2_adaptive_window {
235 endpoint = endpoint.http2_adaptive_window(enabled);
236 }
237 if let Some(tls_config) = &self.inner.client_tls_config {
238 endpoint = endpoint
239 .tls_config(tls_config.clone())
240 .context(CreateChannelSnafu)?;
241 }
242
243 endpoint = endpoint
244 .tcp_keepalive(self.config().tcp_keepalive)
245 .tcp_nodelay(self.config().tcp_nodelay);
246
247 Ok(endpoint)
248 }
249
250 fn trigger_channel_recycling(&self) {
251 if self
252 .inner
253 .channel_recycle_started
254 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
255 .is_err()
256 {
257 return;
258 }
259
260 let pool = self.pool().clone();
261 let cancel = self.inner.cancel.clone();
262 let id = self.inner.id;
263 let _handle = common_runtime::spawn_global(async move {
264 recycle_channel_in_loop(pool, id, cancel, RECYCLE_CHANNEL_INTERVAL_SECS).await;
265 });
266 info!(
267 "ChannelManager: {}, channel recycle is started, running in the background!",
268 self.inner.id
269 );
270 }
271}
272
/// File-system locations of the PEM material needed for a TLS client.
///
/// `ChannelManager::with_tls_config` reads each path as a string.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientTlsOption {
    /// Path of the CA certificate used to verify the server.
    pub server_ca_cert_path: String,
    /// Path of the client certificate.
    pub client_cert_path: String,
    /// Path of the private key belonging to the client certificate.
    pub client_key_path: String,
}
279
280#[derive(Clone, Debug, PartialEq, Eq)]
281pub struct ChannelConfig {
282 pub timeout: Option<Duration>,
283 pub connect_timeout: Option<Duration>,
284 pub concurrency_limit: Option<usize>,
285 pub rate_limit: Option<(u64, Duration)>,
286 pub initial_stream_window_size: Option<u32>,
287 pub initial_connection_window_size: Option<u32>,
288 pub http2_keep_alive_interval: Option<Duration>,
289 pub http2_keep_alive_timeout: Option<Duration>,
290 pub http2_keep_alive_while_idle: Option<bool>,
291 pub http2_adaptive_window: Option<bool>,
292 pub tcp_keepalive: Option<Duration>,
293 pub tcp_nodelay: bool,
294 pub client_tls: Option<ClientTlsOption>,
295 pub max_recv_message_size: ReadableSize,
297 pub max_send_message_size: ReadableSize,
299}
300
301impl Default for ChannelConfig {
302 fn default() -> Self {
303 Self {
304 timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
305 connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
306 concurrency_limit: None,
307 rate_limit: None,
308 initial_stream_window_size: None,
309 initial_connection_window_size: None,
310 http2_keep_alive_interval: Some(Duration::from_secs(30)),
311 http2_keep_alive_timeout: None,
312 http2_keep_alive_while_idle: Some(true),
313 http2_adaptive_window: None,
314 tcp_keepalive: None,
315 tcp_nodelay: true,
316 client_tls: None,
317 max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
318 max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
319 }
320 }
321}
322
323impl ChannelConfig {
324 pub fn new() -> Self {
325 Default::default()
326 }
327
328 pub fn timeout(mut self, timeout: Duration) -> Self {
330 self.timeout = Some(timeout);
331 self
332 }
333
334 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
338 self.connect_timeout = Some(timeout);
339 self
340 }
341
342 pub fn concurrency_limit(mut self, limit: usize) -> Self {
344 self.concurrency_limit = Some(limit);
345 self
346 }
347
348 pub fn rate_limit(mut self, limit: u64, duration: Duration) -> Self {
350 self.rate_limit = Some((limit, duration));
351 self
352 }
353
354 pub fn initial_stream_window_size(mut self, size: u32) -> Self {
357 self.initial_stream_window_size = Some(size);
358 self
359 }
360
361 pub fn initial_connection_window_size(mut self, size: u32) -> Self {
365 self.initial_connection_window_size = Some(size);
366 self
367 }
368
369 pub fn http2_keep_alive_interval(mut self, duration: Duration) -> Self {
371 self.http2_keep_alive_interval = Some(duration);
372 self
373 }
374
375 pub fn http2_keep_alive_timeout(mut self, duration: Duration) -> Self {
377 self.http2_keep_alive_timeout = Some(duration);
378 self
379 }
380
381 pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> Self {
383 self.http2_keep_alive_while_idle = Some(enabled);
384 self
385 }
386
387 pub fn http2_adaptive_window(mut self, enabled: bool) -> Self {
389 self.http2_adaptive_window = Some(enabled);
390 self
391 }
392
393 pub fn tcp_keepalive(mut self, duration: Duration) -> Self {
400 self.tcp_keepalive = Some(duration);
401 self
402 }
403
404 pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
408 self.tcp_nodelay = enabled;
409 self
410 }
411
412 pub fn client_tls_config(mut self, client_tls_option: ClientTlsOption) -> Self {
416 self.client_tls = Some(client_tls_option);
417 self
418 }
419}
420
421#[derive(Debug)]
422pub struct Channel {
423 channel: InnerChannel,
424 access: AtomicUsize,
425 use_default_connector: bool,
426}
427
428impl Channel {
429 #[inline]
430 pub fn access(&self) -> usize {
431 self.access.load(Ordering::Relaxed)
432 }
433
434 #[inline]
435 pub fn use_default_connector(&self) -> bool {
436 self.use_default_connector
437 }
438
439 #[inline]
440 pub fn increase_access(&self) {
441 let _ = self.access.fetch_add(1, Ordering::Relaxed);
442 }
443}
444
445#[derive(Debug, Default)]
446struct Pool {
447 channels: DashMap<String, Channel>,
448}
449
450impl Pool {
451 fn get(&self, addr: &str) -> Option<InnerChannel> {
452 let channel = self.channels.get(addr);
453 channel.map(|ch| {
454 ch.increase_access();
455 ch.channel.clone()
456 })
457 }
458
459 fn entry(&self, addr: String) -> Entry<String, Channel> {
460 self.channels.entry(addr)
461 }
462
463 #[cfg(test)]
464 fn get_access(&self, addr: &str) -> Option<usize> {
465 let channel = self.channels.get(addr);
466 channel.map(|ch| ch.access())
467 }
468
469 fn put(&self, addr: &str, channel: Channel) {
470 let _ = self.channels.insert(addr.to_string(), channel);
471 }
472
473 fn retain_channel<F>(&self, f: F)
474 where
475 F: FnMut(&String, &mut Channel) -> bool,
476 {
477 self.channels.retain(f);
478 }
479}
480
481async fn recycle_channel_in_loop(
482 pool: Arc<Pool>,
483 id: u64,
484 cancel: CancellationToken,
485 interval_secs: u64,
486) {
487 let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
488
489 loop {
490 tokio::select! {
491 _ = cancel.cancelled() => {
492 info!("Stop channel recycle, ChannelManager id: {}", id);
493 break;
494 },
495 _ = interval.tick() => {}
496 }
497
498 pool.retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0)
499 }
500}
501
#[cfg(test)]
mod tests {
    use tower::service_fn;

    use super::*;

    /// Fetching and unwrapping a channel for an address that already carries a
    /// scheme must panic.
    // NOTE(review): `should_panic` does not say which step panics; confirm it
    // is the `unwrap` on an `Err` returned by `get`.
    #[should_panic]
    #[test]
    fn test_invalid_addr() {
        let manager = ChannelManager::default();
        let addr_with_scheme = "http://test";

        let _ = manager.get(addr_with_scheme).unwrap();
    }

    /// Every `get` counts as one access, and a recycling pass resets the count.
    #[tokio::test]
    async fn test_access_count() {
        let manager = ChannelManager::new();
        // Pretend the recycling task is already running so none is spawned.
        manager
            .inner
            .channel_recycle_started
            .store(true, Ordering::Relaxed);
        let manager = Arc::new(manager);
        let addr = "test_uri";

        // 10 tasks x 100 calls = 1000 accesses.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let manager = manager.clone();
                tokio::spawn(async move {
                    for _ in 0..100 {
                        let _ = manager.get(addr);
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.await.unwrap();
        }

        assert_eq!(1000, manager.pool().get_access(addr).unwrap());

        // Same closure as the recycling loop: the channel is kept (its count
        // was non-zero) but the counter is reset.
        manager
            .pool()
            .retain_channel(|_, channel| channel.access.swap(0, Ordering::Relaxed) != 0);

        assert_eq!(0, manager.pool().get_access(addr).unwrap());
    }

    /// Checks the default configuration and that every builder method sets
    /// exactly its own field.
    #[test]
    fn test_config() {
        let defaults = ChannelConfig::new();
        let expected_defaults = ChannelConfig {
            timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
            connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
            concurrency_limit: None,
            rate_limit: None,
            initial_stream_window_size: None,
            initial_connection_window_size: None,
            http2_keep_alive_interval: Some(Duration::from_secs(30)),
            http2_keep_alive_timeout: None,
            http2_keep_alive_while_idle: Some(true),
            http2_adaptive_window: None,
            tcp_keepalive: None,
            tcp_nodelay: true,
            client_tls: None,
            max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
            max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
        };
        assert_eq!(expected_defaults, defaults);

        let customized = defaults
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(false)
            .client_tls_config(ClientTlsOption {
                server_ca_cert_path: "some_server_path".to_string(),
                client_cert_path: "some_cert_path".to_string(),
                client_key_path: "some_key_path".to_string(),
            });

        let expected_customized = ChannelConfig {
            timeout: Some(Duration::from_secs(3)),
            connect_timeout: Some(Duration::from_secs(5)),
            concurrency_limit: Some(6),
            rate_limit: Some((5, Duration::from_secs(1))),
            initial_stream_window_size: Some(10),
            initial_connection_window_size: Some(20),
            http2_keep_alive_interval: Some(Duration::from_secs(1)),
            http2_keep_alive_timeout: Some(Duration::from_secs(3)),
            http2_keep_alive_while_idle: Some(true),
            http2_adaptive_window: Some(true),
            tcp_keepalive: Some(Duration::from_secs(2)),
            tcp_nodelay: false,
            client_tls: Some(ClientTlsOption {
                server_ca_cert_path: "some_server_path".to_string(),
                client_cert_path: "some_cert_path".to_string(),
                client_key_path: "some_key_path".to_string(),
            }),
            max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
            max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
        };
        assert_eq!(expected_customized, customized);
    }

    /// An endpoint can be built with every non-TLS option set.
    #[test]
    fn test_build_endpoint() {
        let manager = ChannelManager::with_config(
            ChannelConfig::new()
                .timeout(Duration::from_secs(3))
                .connect_timeout(Duration::from_secs(5))
                .concurrency_limit(6)
                .rate_limit(5, Duration::from_secs(1))
                .initial_stream_window_size(10)
                .initial_connection_window_size(20)
                .http2_keep_alive_interval(Duration::from_secs(1))
                .http2_keep_alive_timeout(Duration::from_secs(3))
                .http2_keep_alive_while_idle(true)
                .http2_adaptive_window(true)
                .tcp_keepalive(Duration::from_secs(2))
                .tcp_nodelay(true),
        );

        let built = manager.build_endpoint("test_addr");

        let _ = built.unwrap();
    }

    /// `reset_with_connector` replaces a default-connector channel with one
    /// flagged as using a custom connector.
    #[tokio::test]
    async fn test_channel_with_connector() {
        let manager = ChannelManager::new();

        let addr = "test_addr";
        let first = manager.get(addr);
        let _ = first.unwrap();

        manager.retain_channel(|key, pooled| {
            assert_eq!("test_addr", key);
            assert!(pooled.use_default_connector());
            true
        });

        // The connector hands out one in-memory duplex stream, once.
        let (client_half, _) = tokio::io::duplex(1024);
        let mut io = Some(hyper_util::rt::TokioIo::new(client_half));
        let replaced = manager.reset_with_connector(
            addr,
            service_fn(move |_| {
                let io = io.take().unwrap();
                async move { Ok::<_, std::io::Error>(io) }
            }),
        );

        let _ = replaced.unwrap();

        manager.retain_channel(|key, pooled| {
            assert_eq!("test_addr", key);
            assert!(!pooled.use_default_connector());
            true
        });
    }

    /// Dropping the last manager stops the recycling task, which then releases
    /// its reference to the pool.
    #[tokio::test]
    async fn test_pool_release_with_channel_recycle() {
        let manager = ChannelManager::new();

        let pool_holder = manager.pool().clone();

        // The first `get` spawns the recycling task, which holds its own
        // reference to the pool.
        let addr = "test_addr";
        let _ = manager.get(addr);

        let clone_a = manager.clone();
        let clone_b = manager.clone();
        // References: shared state, `pool_holder`, recycling task. Manager
        // clones share the state and add none.
        assert_eq!(3, Arc::strong_count(manager.pool()));

        drop(clone_a);
        drop(clone_b);
        assert_eq!(3, Arc::strong_count(manager.pool()));

        drop(manager);

        // Give the recycling task time to observe the cancellation and exit.
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }

    /// Without a recycling task, the pool is released as soon as the last
    /// manager is dropped.
    #[tokio::test]
    async fn test_pool_release_without_channel_recycle() {
        let manager = ChannelManager::new();

        let pool_holder = manager.pool().clone();

        let clone_a = manager.clone();
        let clone_b = manager.clone();
        // References: shared state and `pool_holder` only.
        assert_eq!(2, Arc::strong_count(manager.pool()));

        drop(clone_a);
        drop(clone_b);
        assert_eq!(2, Arc::strong_count(manager.pool()));

        drop(manager);

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }
}