1use std::future::Future;
16use std::sync::Arc;
17use std::time::Duration;
18
19use api::v1::meta::procedure_service_client::ProcedureServiceClient;
20use api::v1::meta::{
21 DdlTaskRequest, DdlTaskResponse, GcRegionsRequest, GcRegionsResponse, GcTableRequest,
22 GcTableResponse, MigrateRegionRequest, MigrateRegionResponse, ProcedureDetailRequest,
23 ProcedureDetailResponse, ProcedureId, ProcedureStateResponse, QueryProcedureRequest,
24 ReconcileRequest, ReconcileResponse, RequestHeader, ResponseHeader, Role,
25};
26use common_grpc::channel_manager::ChannelManager;
27use common_meta::rpc::procedure::{
28 GcRegionsRequest as MetaGcRegionsRequest, GcResponse as MetaGcResponse,
29 GcTableRequest as MetaGcTableRequest,
30};
31use common_telemetry::tracing_context::TracingContext;
32use common_telemetry::{error, info, warn};
33use snafu::{ResultExt, ensure};
34use tokio::sync::RwLock;
35use tonic::codec::CompressionEncoding;
36use tonic::transport::Channel;
37use tonic::{Request, Status};
38
39use crate::client::{Id, LeaderProviderRef, util};
40use crate::error;
41use crate::error::Result;
42
43#[derive(Clone, Debug)]
44pub struct Client {
45 inner: Arc<RwLock<Inner>>,
46}
47
48impl Client {
49 pub fn new(
50 id: Id,
51 role: Role,
52 channel_manager: ChannelManager,
53 max_retry: usize,
54 timeout: Duration,
55 ) -> Self {
56 let inner = Arc::new(RwLock::new(Inner {
57 id,
58 role,
59 channel_manager,
60 leader_provider: None,
61 max_retry,
62 timeout,
63 }));
64
65 Self { inner }
66 }
67
68 pub(crate) async fn start_with(&self, leader_provider: LeaderProviderRef) -> Result<()> {
70 let mut inner = self.inner.write().await;
71 inner.start_with(leader_provider)
72 }
73
74 pub async fn submit_ddl_task(&self, req: DdlTaskRequest) -> Result<DdlTaskResponse> {
75 let inner = self.inner.read().await;
76 inner.submit_ddl_task(req).await
77 }
78
79 pub async fn query_procedure_state(&self, pid: &str) -> Result<ProcedureStateResponse> {
81 let inner = self.inner.read().await;
82 inner.query_procedure_state(pid).await
83 }
84
85 pub async fn migrate_region(
91 &self,
92 region_id: u64,
93 from_peer: u64,
94 to_peer: u64,
95 timeout: Duration,
96 ) -> Result<MigrateRegionResponse> {
97 let inner = self.inner.read().await;
98 inner
99 .migrate_region(region_id, from_peer, to_peer, timeout)
100 .await
101 }
102
103 pub async fn reconcile(&self, request: ReconcileRequest) -> Result<ReconcileResponse> {
105 let inner = self.inner.read().await;
106 inner.reconcile(request).await
107 }
108
109 pub async fn list_procedures(&self) -> Result<ProcedureDetailResponse> {
110 let inner = self.inner.read().await;
111 inner.list_procedures().await
112 }
113
114 pub async fn gc_regions(&self, request: MetaGcRegionsRequest) -> Result<MetaGcResponse> {
115 let inner = self.inner.read().await;
116 inner.gc_regions(request).await
117 }
118
119 pub async fn gc_table(&self, request: MetaGcTableRequest) -> Result<MetaGcResponse> {
120 let inner = self.inner.read().await;
121 inner.gc_table(request).await
122 }
123}
124
125#[derive(Debug)]
126struct Inner {
127 id: Id,
128 role: Role,
129 channel_manager: ChannelManager,
130 leader_provider: Option<LeaderProviderRef>,
131 max_retry: usize,
132 timeout: Duration,
134}
135
136impl Inner {
137 fn start_with(&mut self, leader_provider: LeaderProviderRef) -> Result<()> {
138 ensure!(
139 !self.is_started(),
140 error::IllegalGrpcClientStateSnafu {
141 err_msg: "DDL client already started",
142 }
143 );
144 self.leader_provider = Some(leader_provider);
145 Ok(())
146 }
147
148 fn make_client(&self, addr: impl AsRef<str>) -> Result<ProcedureServiceClient<Channel>> {
149 let channel = self
150 .channel_manager
151 .get(addr)
152 .context(error::CreateChannelSnafu)?;
153
154 Ok(ProcedureServiceClient::new(channel)
155 .accept_compressed(CompressionEncoding::Gzip)
156 .accept_compressed(CompressionEncoding::Zstd)
157 .send_compressed(CompressionEncoding::Zstd))
158 }
159
160 #[inline]
161 fn is_started(&self) -> bool {
162 self.leader_provider.is_some()
163 }
164
165 async fn with_retry<T, F, R, H>(&self, task: &str, body_fn: F, get_header: H) -> Result<T>
166 where
167 R: Future<Output = std::result::Result<T, Status>>,
168 F: Fn(ProcedureServiceClient<Channel>) -> R,
169 H: Fn(&T) -> &Option<ResponseHeader>,
170 {
171 let Some(leader_provider) = self.leader_provider.as_ref() else {
172 return error::IllegalGrpcClientStateSnafu {
173 err_msg: "not started",
174 }
175 .fail();
176 };
177
178 let mut times = 0;
179 let mut last_error = None;
180
181 while times < self.max_retry {
182 if let Some(leader) = &leader_provider.leader() {
183 let client = self.make_client(leader)?;
184 match body_fn(client).await {
185 Ok(res) => {
186 if util::is_not_leader(get_header(&res)) {
187 last_error = Some(format!("{leader} is not a leader"));
188 warn!("Failed to {task} to {leader}, not a leader");
189 let leader = leader_provider.ask_leader().await?;
190 info!("DDL client updated to new leader addr: {leader}");
191 times += 1;
192 continue;
193 }
194 return Ok(res);
195 }
196 Err(status) => {
197 if util::is_unreachable(&status) {
199 last_error = Some(status.to_string());
200 warn!("Failed to {task} to {leader}, source: {status}");
201 let leader = leader_provider.ask_leader().await?;
202 info!("Procedure client updated to new leader addr: {leader}");
203 times += 1;
204 continue;
205 } else {
206 error!("An error occurred in gRPC, status: {status:?}");
207 return Err(error::Error::from(status));
208 }
209 }
210 }
211 } else if let Err(err) = leader_provider.ask_leader().await {
212 return Err(err);
213 }
214 }
215
216 error::RetryTimesExceededSnafu {
217 msg: format!("Failed to {task}, last error: {:?}", last_error),
218 times: self.max_retry,
219 }
220 .fail()
221 }
222
223 async fn migrate_region(
224 &self,
225 region_id: u64,
226 from_peer: u64,
227 to_peer: u64,
228 timeout: Duration,
229 ) -> Result<MigrateRegionResponse> {
230 let mut req = MigrateRegionRequest {
231 region_id,
232 from_peer,
233 to_peer,
234 timeout_secs: timeout.as_secs() as u32,
235 ..Default::default()
236 };
237
238 req.set_header(
239 self.id,
240 self.role,
241 TracingContext::from_current_span().to_w3c(),
242 );
243
244 self.with_retry(
245 "migrate region",
246 move |mut client| {
247 let mut req = Request::new(req.clone());
248 req.set_timeout(self.timeout);
249
250 async move { client.migrate(req).await.map(|res| res.into_inner()) }
251 },
252 |resp: &MigrateRegionResponse| &resp.header,
253 )
254 .await
255 }
256
257 async fn reconcile(&self, request: ReconcileRequest) -> Result<ReconcileResponse> {
258 let mut req = request;
259 req.set_header(
260 self.id,
261 self.role,
262 TracingContext::from_current_span().to_w3c(),
263 );
264
265 self.with_retry(
266 "reconcile",
267 move |mut client| {
268 let mut req = Request::new(req.clone());
269 req.set_timeout(self.timeout);
270
271 async move { client.reconcile(req).await.map(|res| res.into_inner()) }
272 },
273 |resp: &ReconcileResponse| &resp.header,
274 )
275 .await
276 }
277
278 async fn gc_regions(&self, request: MetaGcRegionsRequest) -> Result<MetaGcResponse> {
279 let timeout = request.timeout;
280 let req = GcRegionsRequest {
281 header: Some(RequestHeader {
282 protocol_version: 0,
283 member_id: self.id,
284 role: self.role as i32,
285 tracing_context: TracingContext::from_current_span().to_w3c(),
286 }),
287 region_ids: request.region_ids,
288 full_file_listing: request.full_file_listing,
289 timeout_secs: timeout.as_secs() as u32,
290 };
291
292 let resp: GcRegionsResponse = self
293 .with_retry(
294 "gc_regions",
295 move |mut client| {
296 let mut req = Request::new(req.clone());
297 req.set_timeout(timeout);
298 async move { client.gc_regions(req).await.map(|res| res.into_inner()) }
299 },
300 |resp: &GcRegionsResponse| &resp.header,
301 )
302 .await?;
303
304 let stats = resp.stats.unwrap_or_default();
305 Ok(MetaGcResponse {
306 processed_regions: stats.processed_regions,
307 need_retry_regions: stats.need_retry_regions,
308 deleted_files: stats.deleted_files,
309 deleted_indexes: stats.deleted_indexes,
310 })
311 }
312
313 async fn gc_table(&self, request: MetaGcTableRequest) -> Result<MetaGcResponse> {
314 let timeout = request.timeout;
315 let req = GcTableRequest {
316 header: Some(RequestHeader {
317 protocol_version: 0,
318 member_id: self.id,
319 role: self.role as i32,
320 tracing_context: TracingContext::from_current_span().to_w3c(),
321 }),
322 catalog_name: request.catalog_name,
323 schema_name: request.schema_name,
324 table_name: request.table_name,
325 full_file_listing: request.full_file_listing,
326 timeout_secs: timeout.as_secs() as u32,
327 };
328
329 let resp: GcTableResponse = self
330 .with_retry(
331 "gc_table",
332 move |mut client| {
333 let mut req = Request::new(req.clone());
334 req.set_timeout(timeout);
335 async move { client.gc_table(req).await.map(|res| res.into_inner()) }
336 },
337 |resp: &GcTableResponse| &resp.header,
338 )
339 .await?;
340
341 let stats = resp.stats.unwrap_or_default();
342 Ok(MetaGcResponse {
343 processed_regions: stats.processed_regions,
344 need_retry_regions: stats.need_retry_regions,
345 deleted_files: stats.deleted_files,
346 deleted_indexes: stats.deleted_indexes,
347 })
348 }
349
350 async fn query_procedure_state(&self, pid: &str) -> Result<ProcedureStateResponse> {
351 let mut req = QueryProcedureRequest {
352 pid: Some(ProcedureId { key: pid.into() }),
353 ..Default::default()
354 };
355
356 req.set_header(
357 self.id,
358 self.role,
359 TracingContext::from_current_span().to_w3c(),
360 );
361
362 self.with_retry(
363 "query procedure state",
364 move |mut client| {
365 let mut req = Request::new(req.clone());
366 req.set_timeout(self.timeout);
367
368 async move { client.query(req).await.map(|res| res.into_inner()) }
369 },
370 |resp: &ProcedureStateResponse| &resp.header,
371 )
372 .await
373 }
374
375 async fn submit_ddl_task(&self, mut req: DdlTaskRequest) -> Result<DdlTaskResponse> {
376 req.set_header(
377 self.id,
378 self.role,
379 TracingContext::from_current_span().to_w3c(),
380 );
381 let timeout = Duration::from_secs(req.timeout_secs.into());
382
383 self.with_retry(
384 "submit ddl task",
385 move |mut client| {
386 let mut req = Request::new(req.clone());
387 req.set_timeout(timeout);
388 async move { client.ddl(req).await.map(|res| res.into_inner()) }
389 },
390 |resp: &DdlTaskResponse| &resp.header,
391 )
392 .await
393 }
394
395 async fn list_procedures(&self) -> Result<ProcedureDetailResponse> {
396 let mut req = ProcedureDetailRequest::default();
397 req.set_header(
398 self.id,
399 self.role,
400 TracingContext::from_current_span().to_w3c(),
401 );
402
403 self.with_retry(
404 "list procedure",
405 move |mut client| {
406 let mut req = Request::new(req.clone());
407 req.set_timeout(self.timeout);
408 async move { client.details(req).await.map(|res| res.into_inner()) }
409 },
410 |resp: &ProcedureDetailResponse| &resp.header,
411 )
412 .await
413 }
414}
415
#[cfg(test)]
mod tests {
    //! Integration test: a DDL submitted through the meta client must honor its own
    //! timeout even when the server-side handler is slower.

    use std::time::{Duration, Instant};

    use api::v1::meta::heartbeat_server::{Heartbeat, HeartbeatServer};
    use api::v1::meta::procedure_service_server::{ProcedureService, ProcedureServiceServer};
    use api::v1::meta::{
        AskLeaderRequest, AskLeaderResponse, DdlTaskRequest, DdlTaskResponse, GcRegionsRequest,
        GcRegionsResponse, GcTableRequest, GcTableResponse, HeartbeatRequest, HeartbeatResponse,
        MigrateRegionRequest, MigrateRegionResponse, Peer, ProcedureDetailRequest,
        ProcedureDetailResponse, ProcedureStateResponse, QueryProcedureRequest, ReconcileRequest,
        ReconcileResponse, ResponseHeader, Role,
    };
    use async_trait::async_trait;
    use common_error::status_code::StatusCode;
    use common_meta::rpc::ddl::{
        CommentObjectType, CommentOnTask, DdlTask, QueryContext, SubmitDdlTaskRequest,
    };
    use common_telemetry::common_error::ext::ErrorExt;
    use common_telemetry::info;
    use tokio::net::TcpListener;
    use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream};
    use tonic::codec::CompressionEncoding;
    use tonic::{Request, Response, Status};

    use crate::client::MetaClientBuilder;

    /// Heartbeat mock whose only job is to answer `ask_leader` with a fixed address.
    #[derive(Clone)]
    struct MockHeartbeat {
        // Address reported as the leader; the test points it at the mock server itself.
        leader_addr: String,
    }

    #[async_trait]
    impl Heartbeat for MockHeartbeat {
        type HeartbeatStream = ReceiverStream<Result<HeartbeatResponse, Status>>;

        async fn heartbeat(
            &self,
            _request: Request<tonic::Streaming<HeartbeatRequest>>,
        ) -> Result<Response<Self::HeartbeatStream>, Status> {
            Err(Status::unimplemented(
                "heartbeat stream is not used in this test",
            ))
        }

        /// Always reports peer 1 at `leader_addr` as leader, with an error-free header.
        async fn ask_leader(
            &self,
            _request: Request<AskLeaderRequest>,
        ) -> Result<Response<AskLeaderResponse>, Status> {
            Ok(Response::new(AskLeaderResponse {
                header: Some(ResponseHeader {
                    protocol_version: 0,
                    error: None,
                }),
                leader: Some(Peer {
                    id: 1,
                    addr: self.leader_addr.clone(),
                }),
            }))
        }
    }

    /// Procedure-service mock: only `ddl` is implemented, and it sleeps `delay` first.
    #[derive(Clone)]
    struct MockProcedure {
        // How long `ddl` waits before replying; set longer than the client timeout.
        delay: Duration,
    }

    #[async_trait]
    impl ProcedureService for MockProcedure {
        async fn query(
            &self,
            _request: Request<QueryProcedureRequest>,
        ) -> Result<Response<ProcedureStateResponse>, Status> {
            Err(Status::unimplemented("query is not used in this test"))
        }

        /// Sleeps for `delay`, then returns a successful (error-free) response.
        async fn ddl(
            &self,
            _request: Request<DdlTaskRequest>,
        ) -> Result<Response<DdlTaskResponse>, Status> {
            tokio::time::sleep(self.delay).await;
            Ok(Response::new(DdlTaskResponse {
                header: Some(ResponseHeader {
                    protocol_version: 0,
                    error: None,
                }),
                ..Default::default()
            }))
        }

        async fn reconcile(
            &self,
            _request: Request<ReconcileRequest>,
        ) -> Result<Response<ReconcileResponse>, Status> {
            Err(Status::unimplemented("reconcile is not used in this test"))
        }

        async fn migrate(
            &self,
            _request: Request<MigrateRegionRequest>,
        ) -> Result<Response<MigrateRegionResponse>, Status> {
            Err(Status::unimplemented("migrate is not used in this test"))
        }

        async fn details(
            &self,
            _request: Request<ProcedureDetailRequest>,
        ) -> Result<Response<ProcedureDetailResponse>, Status> {
            Err(Status::unimplemented("details is not used in this test"))
        }

        async fn gc_regions(
            &self,
            _request: Request<GcRegionsRequest>,
        ) -> Result<Response<GcRegionsResponse>, Status> {
            Err(Status::unimplemented("gc_regions is not used in this test"))
        }

        async fn gc_table(
            &self,
            _request: Request<GcTableRequest>,
        ) -> Result<Response<GcTableResponse>, Status> {
            Err(Status::unimplemented("gc_table is not used in this test"))
        }
    }

    /// The server-side DDL handler takes 4 s and the request carries a 1 s timeout.
    /// The call must fail in under 2 s with a `Cancelled` / "Timeout expired" error.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_meta_client_ddl_request_timeout() {
        common_telemetry::init_default_ut_logging();

        // Bind an ephemeral port so parallel test runs don't collide.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let addr_str = addr.to_string();

        // The heartbeat service reports this same server as leader, so procedure RPCs
        // are routed back to `MockProcedure` on the same listener.
        let heartbeat = MockHeartbeat {
            leader_addr: addr_str.clone(),
        };
        let procedure = MockProcedure {
            delay: Duration::from_secs(4),
        };

        // Both services accept zstd because the client sends zstd-compressed requests.
        let server = tonic::transport::Server::builder()
            .add_service(
                HeartbeatServer::new(heartbeat).accept_compressed(CompressionEncoding::Zstd),
            )
            .add_service(
                ProcedureServiceServer::new(procedure).accept_compressed(CompressionEncoding::Zstd),
            )
            .serve_with_incoming(TcpListenerStream::new(listener));
        let server_handle = tokio::spawn(server);

        let mut client = MetaClientBuilder::new(0, Role::Frontend)
            .enable_heartbeat()
            .enable_procedure()
            .build();
        client.start(&[addr_str.as_str()]).await.unwrap();

        // Any DDL task works here; the mock ignores the payload and only sleeps.
        let mut request = SubmitDdlTaskRequest::new(
            QueryContext::default(),
            DdlTask::new_comment_on(CommentOnTask {
                catalog_name: "greptime".to_string(),
                schema_name: "public".to_string(),
                object_type: CommentObjectType::Table,
                object_name: "test_table".to_string(),
                column_name: None,
                object_id: None,
                comment: Some("timeout".to_string()),
            }),
        );
        // Shorter than the 4 s server delay, so the request must time out.
        request.timeout = Duration::from_secs(1);

        let now = Instant::now();
        let err = client.submit_ddl_task(request).await.unwrap_err();
        let elapsed = now.elapsed();
        // Well under the 4 s handler delay: the request's own timeout was enforced.
        assert!(elapsed < Duration::from_secs(2));
        info!("err: {err:?}, code: {}", err.status_code());
        assert_eq!(err.status_code(), StatusCode::Cancelled);
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("Timeout expired"),
            "unexpected error: {err_msg}"
        );

        server_handle.abort();
    }
}
602}