1use std::future::Future;
16use std::sync::Arc;
17use std::time::Duration;
18
19use api::v1::meta::procedure_service_client::ProcedureServiceClient;
20use api::v1::meta::{
21 DdlTaskRequest, DdlTaskResponse, MigrateRegionRequest, MigrateRegionResponse,
22 ProcedureDetailRequest, ProcedureDetailResponse, ProcedureId, ProcedureStateResponse,
23 QueryProcedureRequest, ReconcileRequest, ReconcileResponse, ResponseHeader, Role,
24};
25use common_grpc::channel_manager::ChannelManager;
26use common_telemetry::tracing_context::TracingContext;
27use common_telemetry::{error, info, warn};
28use snafu::{ResultExt, ensure};
29use tokio::sync::RwLock;
30use tonic::codec::CompressionEncoding;
31use tonic::transport::Channel;
32use tonic::{Request, Status};
33
34use crate::client::{Id, LeaderProviderRef, util};
35use crate::error;
36use crate::error::Result;
37
38#[derive(Clone, Debug)]
39pub struct Client {
40 inner: Arc<RwLock<Inner>>,
41}
42
43impl Client {
44 pub fn new(
45 id: Id,
46 role: Role,
47 channel_manager: ChannelManager,
48 max_retry: usize,
49 timeout: Duration,
50 ) -> Self {
51 let inner = Arc::new(RwLock::new(Inner {
52 id,
53 role,
54 channel_manager,
55 leader_provider: None,
56 max_retry,
57 timeout,
58 }));
59
60 Self { inner }
61 }
62
63 pub(crate) async fn start_with(&self, leader_provider: LeaderProviderRef) -> Result<()> {
65 let mut inner = self.inner.write().await;
66 inner.start_with(leader_provider)
67 }
68
69 pub async fn submit_ddl_task(&self, req: DdlTaskRequest) -> Result<DdlTaskResponse> {
70 let inner = self.inner.read().await;
71 inner.submit_ddl_task(req).await
72 }
73
74 pub async fn query_procedure_state(&self, pid: &str) -> Result<ProcedureStateResponse> {
76 let inner = self.inner.read().await;
77 inner.query_procedure_state(pid).await
78 }
79
80 pub async fn migrate_region(
86 &self,
87 region_id: u64,
88 from_peer: u64,
89 to_peer: u64,
90 timeout: Duration,
91 ) -> Result<MigrateRegionResponse> {
92 let inner = self.inner.read().await;
93 inner
94 .migrate_region(region_id, from_peer, to_peer, timeout)
95 .await
96 }
97
98 pub async fn reconcile(&self, request: ReconcileRequest) -> Result<ReconcileResponse> {
100 let inner = self.inner.read().await;
101 inner.reconcile(request).await
102 }
103
104 pub async fn list_procedures(&self) -> Result<ProcedureDetailResponse> {
105 let inner = self.inner.read().await;
106 inner.list_procedures().await
107 }
108}
109
110#[derive(Debug)]
111struct Inner {
112 id: Id,
113 role: Role,
114 channel_manager: ChannelManager,
115 leader_provider: Option<LeaderProviderRef>,
116 max_retry: usize,
117 timeout: Duration,
119}
120
121impl Inner {
122 fn start_with(&mut self, leader_provider: LeaderProviderRef) -> Result<()> {
123 ensure!(
124 !self.is_started(),
125 error::IllegalGrpcClientStateSnafu {
126 err_msg: "DDL client already started",
127 }
128 );
129 self.leader_provider = Some(leader_provider);
130 Ok(())
131 }
132
133 fn make_client(&self, addr: impl AsRef<str>) -> Result<ProcedureServiceClient<Channel>> {
134 let channel = self
135 .channel_manager
136 .get(addr)
137 .context(error::CreateChannelSnafu)?;
138
139 Ok(ProcedureServiceClient::new(channel)
140 .accept_compressed(CompressionEncoding::Gzip)
141 .accept_compressed(CompressionEncoding::Zstd)
142 .send_compressed(CompressionEncoding::Zstd))
143 }
144
145 #[inline]
146 fn is_started(&self) -> bool {
147 self.leader_provider.is_some()
148 }
149
150 async fn with_retry<T, F, R, H>(&self, task: &str, body_fn: F, get_header: H) -> Result<T>
151 where
152 R: Future<Output = std::result::Result<T, Status>>,
153 F: Fn(ProcedureServiceClient<Channel>) -> R,
154 H: Fn(&T) -> &Option<ResponseHeader>,
155 {
156 let Some(leader_provider) = self.leader_provider.as_ref() else {
157 return error::IllegalGrpcClientStateSnafu {
158 err_msg: "not started",
159 }
160 .fail();
161 };
162
163 let mut times = 0;
164 let mut last_error = None;
165
166 while times < self.max_retry {
167 if let Some(leader) = &leader_provider.leader() {
168 let client = self.make_client(leader)?;
169 match body_fn(client).await {
170 Ok(res) => {
171 if util::is_not_leader(get_header(&res)) {
172 last_error = Some(format!("{leader} is not a leader"));
173 warn!("Failed to {task} to {leader}, not a leader");
174 let leader = leader_provider.ask_leader().await?;
175 info!("DDL client updated to new leader addr: {leader}");
176 times += 1;
177 continue;
178 }
179 return Ok(res);
180 }
181 Err(status) => {
182 if util::is_unreachable(&status) {
184 last_error = Some(status.to_string());
185 warn!("Failed to {task} to {leader}, source: {status}");
186 let leader = leader_provider.ask_leader().await?;
187 info!("Procedure client updated to new leader addr: {leader}");
188 times += 1;
189 continue;
190 } else {
191 error!("An error occurred in gRPC, status: {status:?}");
192 return Err(error::Error::from(status));
193 }
194 }
195 }
196 } else if let Err(err) = leader_provider.ask_leader().await {
197 return Err(err);
198 }
199 }
200
201 error::RetryTimesExceededSnafu {
202 msg: format!("Failed to {task}, last error: {:?}", last_error),
203 times: self.max_retry,
204 }
205 .fail()
206 }
207
208 async fn migrate_region(
209 &self,
210 region_id: u64,
211 from_peer: u64,
212 to_peer: u64,
213 timeout: Duration,
214 ) -> Result<MigrateRegionResponse> {
215 let mut req = MigrateRegionRequest {
216 region_id,
217 from_peer,
218 to_peer,
219 timeout_secs: timeout.as_secs() as u32,
220 ..Default::default()
221 };
222
223 req.set_header(
224 self.id,
225 self.role,
226 TracingContext::from_current_span().to_w3c(),
227 );
228
229 self.with_retry(
230 "migrate region",
231 move |mut client| {
232 let mut req = Request::new(req.clone());
233 req.set_timeout(self.timeout);
234
235 async move { client.migrate(req).await.map(|res| res.into_inner()) }
236 },
237 |resp: &MigrateRegionResponse| &resp.header,
238 )
239 .await
240 }
241
242 async fn reconcile(&self, request: ReconcileRequest) -> Result<ReconcileResponse> {
243 let mut req = request;
244 req.set_header(
245 self.id,
246 self.role,
247 TracingContext::from_current_span().to_w3c(),
248 );
249
250 self.with_retry(
251 "reconcile",
252 move |mut client| {
253 let mut req = Request::new(req.clone());
254 req.set_timeout(self.timeout);
255
256 async move { client.reconcile(req).await.map(|res| res.into_inner()) }
257 },
258 |resp: &ReconcileResponse| &resp.header,
259 )
260 .await
261 }
262
263 async fn query_procedure_state(&self, pid: &str) -> Result<ProcedureStateResponse> {
264 let mut req = QueryProcedureRequest {
265 pid: Some(ProcedureId { key: pid.into() }),
266 ..Default::default()
267 };
268
269 req.set_header(
270 self.id,
271 self.role,
272 TracingContext::from_current_span().to_w3c(),
273 );
274
275 self.with_retry(
276 "query procedure state",
277 move |mut client| {
278 let mut req = Request::new(req.clone());
279 req.set_timeout(self.timeout);
280
281 async move { client.query(req).await.map(|res| res.into_inner()) }
282 },
283 |resp: &ProcedureStateResponse| &resp.header,
284 )
285 .await
286 }
287
288 async fn submit_ddl_task(&self, mut req: DdlTaskRequest) -> Result<DdlTaskResponse> {
289 req.set_header(
290 self.id,
291 self.role,
292 TracingContext::from_current_span().to_w3c(),
293 );
294 let timeout = Duration::from_secs(req.timeout_secs.into());
295
296 self.with_retry(
297 "submit ddl task",
298 move |mut client| {
299 let mut req = Request::new(req.clone());
300 req.set_timeout(timeout);
301 async move { client.ddl(req).await.map(|res| res.into_inner()) }
302 },
303 |resp: &DdlTaskResponse| &resp.header,
304 )
305 .await
306 }
307
308 async fn list_procedures(&self) -> Result<ProcedureDetailResponse> {
309 let mut req = ProcedureDetailRequest::default();
310 req.set_header(
311 self.id,
312 self.role,
313 TracingContext::from_current_span().to_w3c(),
314 );
315
316 self.with_retry(
317 "list procedure",
318 move |mut client| {
319 let mut req = Request::new(req.clone());
320 req.set_timeout(self.timeout);
321 async move { client.details(req).await.map(|res| res.into_inner()) }
322 },
323 |resp: &ProcedureDetailResponse| &resp.header,
324 )
325 .await
326 }
327}
328
329#[cfg(test)]
330mod tests {
331 use std::time::{Duration, Instant};
332
333 use api::v1::meta::heartbeat_server::{Heartbeat, HeartbeatServer};
334 use api::v1::meta::procedure_service_server::{ProcedureService, ProcedureServiceServer};
335 use api::v1::meta::{
336 AskLeaderRequest, AskLeaderResponse, DdlTaskRequest, DdlTaskResponse, HeartbeatRequest,
337 HeartbeatResponse, MigrateRegionRequest, MigrateRegionResponse, Peer,
338 ProcedureDetailRequest, ProcedureDetailResponse, ProcedureStateResponse,
339 QueryProcedureRequest, ReconcileRequest, ReconcileResponse, ResponseHeader, Role,
340 };
341 use async_trait::async_trait;
342 use common_error::status_code::StatusCode;
343 use common_meta::rpc::ddl::{CommentObjectType, CommentOnTask, DdlTask, SubmitDdlTaskRequest};
344 use common_telemetry::common_error::ext::ErrorExt;
345 use common_telemetry::info;
346 use session::context::QueryContext;
347 use tokio::net::TcpListener;
348 use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream};
349 use tonic::codec::CompressionEncoding;
350 use tonic::{Request, Response, Status};
351
352 use crate::client::MetaClientBuilder;
353
354 #[derive(Clone)]
355 struct MockHeartbeat {
356 leader_addr: String,
357 }
358
359 #[async_trait]
360 impl Heartbeat for MockHeartbeat {
361 type HeartbeatStream = ReceiverStream<Result<HeartbeatResponse, Status>>;
362
363 async fn heartbeat(
364 &self,
365 _request: Request<tonic::Streaming<HeartbeatRequest>>,
366 ) -> Result<Response<Self::HeartbeatStream>, Status> {
367 Err(Status::unimplemented(
368 "heartbeat stream is not used in this test",
369 ))
370 }
371
372 async fn ask_leader(
373 &self,
374 _request: Request<AskLeaderRequest>,
375 ) -> Result<Response<AskLeaderResponse>, Status> {
376 Ok(Response::new(AskLeaderResponse {
377 header: Some(ResponseHeader {
378 protocol_version: 0,
379 error: None,
380 }),
381 leader: Some(Peer {
382 id: 1,
383 addr: self.leader_addr.clone(),
384 }),
385 }))
386 }
387 }
388
389 #[derive(Clone)]
390 struct MockProcedure {
391 delay: Duration,
392 }
393
394 #[async_trait]
395 impl ProcedureService for MockProcedure {
396 async fn query(
397 &self,
398 _request: Request<QueryProcedureRequest>,
399 ) -> Result<Response<ProcedureStateResponse>, Status> {
400 Err(Status::unimplemented("query is not used in this test"))
401 }
402
403 async fn ddl(
404 &self,
405 _request: Request<DdlTaskRequest>,
406 ) -> Result<Response<DdlTaskResponse>, Status> {
407 tokio::time::sleep(self.delay).await;
408 Ok(Response::new(DdlTaskResponse {
409 header: Some(ResponseHeader {
410 protocol_version: 0,
411 error: None,
412 }),
413 ..Default::default()
414 }))
415 }
416
417 async fn reconcile(
418 &self,
419 _request: Request<ReconcileRequest>,
420 ) -> Result<Response<ReconcileResponse>, Status> {
421 Err(Status::unimplemented("reconcile is not used in this test"))
422 }
423
424 async fn migrate(
425 &self,
426 _request: Request<MigrateRegionRequest>,
427 ) -> Result<Response<MigrateRegionResponse>, Status> {
428 Err(Status::unimplemented("migrate is not used in this test"))
429 }
430
431 async fn details(
432 &self,
433 _request: Request<ProcedureDetailRequest>,
434 ) -> Result<Response<ProcedureDetailResponse>, Status> {
435 Err(Status::unimplemented("details is not used in this test"))
436 }
437 }
438
439 #[tokio::test(flavor = "multi_thread")]
440 async fn test_meta_client_ddl_request_timeout() {
441 common_telemetry::init_default_ut_logging();
442
443 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
444 let addr = listener.local_addr().unwrap();
445 let addr_str = addr.to_string();
446
447 let heartbeat = MockHeartbeat {
448 leader_addr: addr_str.clone(),
449 };
450 let procedure = MockProcedure {
451 delay: Duration::from_secs(4),
452 };
453
454 let server = tonic::transport::Server::builder()
455 .add_service(
456 HeartbeatServer::new(heartbeat).accept_compressed(CompressionEncoding::Zstd),
457 )
458 .add_service(
459 ProcedureServiceServer::new(procedure).accept_compressed(CompressionEncoding::Zstd),
460 )
461 .serve_with_incoming(TcpListenerStream::new(listener));
462 let server_handle = tokio::spawn(server);
463
464 let mut client = MetaClientBuilder::new(0, Role::Frontend)
465 .enable_heartbeat()
466 .enable_procedure()
467 .build();
468 client.start(&[addr_str.as_str()]).await.unwrap();
469
470 let mut request = SubmitDdlTaskRequest::new(
471 QueryContext::arc(),
472 DdlTask::new_comment_on(CommentOnTask {
473 catalog_name: "greptime".to_string(),
474 schema_name: "public".to_string(),
475 object_type: CommentObjectType::Table,
476 object_name: "test_table".to_string(),
477 column_name: None,
478 object_id: None,
479 comment: Some("timeout".to_string()),
480 }),
481 );
482 request.timeout = Duration::from_secs(1);
483
484 let now = Instant::now();
485 let err = client.submit_ddl_task(request).await.unwrap_err();
486 let elapsed = now.elapsed();
487 assert!(elapsed < Duration::from_secs(2));
489 info!("err: {err:?}, code: {}", err.status_code());
490 assert_eq!(err.status_code(), StatusCode::Cancelled);
491 let err_msg = err.to_string();
492 assert!(
493 err_msg.contains("Timeout expired"),
494 "unexpected error: {err_msg}"
495 );
496
497 server_handle.abort();
498 }
499}