meta_client/client/
procedure.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::sync::Arc;
17use std::time::Duration;
18
19use api::v1::meta::procedure_service_client::ProcedureServiceClient;
20use api::v1::meta::{
21    DdlTaskRequest, DdlTaskResponse, MigrateRegionRequest, MigrateRegionResponse,
22    ProcedureDetailRequest, ProcedureDetailResponse, ProcedureId, ProcedureStateResponse,
23    QueryProcedureRequest, ReconcileRequest, ReconcileResponse, ResponseHeader, Role,
24};
25use common_grpc::channel_manager::ChannelManager;
26use common_telemetry::tracing_context::TracingContext;
27use common_telemetry::{error, info, warn};
28use snafu::{ResultExt, ensure};
29use tokio::sync::RwLock;
30use tonic::codec::CompressionEncoding;
31use tonic::transport::Channel;
32use tonic::{Request, Status};
33
34use crate::client::{Id, LeaderProviderRef, util};
35use crate::error;
36use crate::error::Result;
37
38#[derive(Clone, Debug)]
39pub struct Client {
40    inner: Arc<RwLock<Inner>>,
41}
42
43impl Client {
44    pub fn new(
45        id: Id,
46        role: Role,
47        channel_manager: ChannelManager,
48        max_retry: usize,
49        timeout: Duration,
50    ) -> Self {
51        let inner = Arc::new(RwLock::new(Inner {
52            id,
53            role,
54            channel_manager,
55            leader_provider: None,
56            max_retry,
57            timeout,
58        }));
59
60        Self { inner }
61    }
62
63    /// Start the client with a [LeaderProvider].
64    pub(crate) async fn start_with(&self, leader_provider: LeaderProviderRef) -> Result<()> {
65        let mut inner = self.inner.write().await;
66        inner.start_with(leader_provider)
67    }
68
69    pub async fn submit_ddl_task(&self, req: DdlTaskRequest) -> Result<DdlTaskResponse> {
70        let inner = self.inner.read().await;
71        inner.submit_ddl_task(req).await
72    }
73
74    /// Query the procedure' state by its id
75    pub async fn query_procedure_state(&self, pid: &str) -> Result<ProcedureStateResponse> {
76        let inner = self.inner.read().await;
77        inner.query_procedure_state(pid).await
78    }
79
80    /// Migrate the region from one datanode to the other datanode:
81    /// - `region_id`:  the migrated region id
82    /// - `from_peer`:  the source datanode id
83    /// - `to_peer`:  the target datanode id
84    /// - `timeout`: timeout for downgrading region and upgrading region operations
85    pub async fn migrate_region(
86        &self,
87        region_id: u64,
88        from_peer: u64,
89        to_peer: u64,
90        timeout: Duration,
91    ) -> Result<MigrateRegionResponse> {
92        let inner = self.inner.read().await;
93        inner
94            .migrate_region(region_id, from_peer, to_peer, timeout)
95            .await
96    }
97
98    /// Reconcile the procedure state.
99    pub async fn reconcile(&self, request: ReconcileRequest) -> Result<ReconcileResponse> {
100        let inner = self.inner.read().await;
101        inner.reconcile(request).await
102    }
103
104    pub async fn list_procedures(&self) -> Result<ProcedureDetailResponse> {
105        let inner = self.inner.read().await;
106        inner.list_procedures().await
107    }
108}
109
110#[derive(Debug)]
111struct Inner {
112    id: Id,
113    role: Role,
114    channel_manager: ChannelManager,
115    leader_provider: Option<LeaderProviderRef>,
116    max_retry: usize,
117    /// Request timeout.
118    timeout: Duration,
119}
120
121impl Inner {
122    fn start_with(&mut self, leader_provider: LeaderProviderRef) -> Result<()> {
123        ensure!(
124            !self.is_started(),
125            error::IllegalGrpcClientStateSnafu {
126                err_msg: "DDL client already started",
127            }
128        );
129        self.leader_provider = Some(leader_provider);
130        Ok(())
131    }
132
133    fn make_client(&self, addr: impl AsRef<str>) -> Result<ProcedureServiceClient<Channel>> {
134        let channel = self
135            .channel_manager
136            .get(addr)
137            .context(error::CreateChannelSnafu)?;
138
139        Ok(ProcedureServiceClient::new(channel)
140            .accept_compressed(CompressionEncoding::Gzip)
141            .accept_compressed(CompressionEncoding::Zstd)
142            .send_compressed(CompressionEncoding::Zstd))
143    }
144
145    #[inline]
146    fn is_started(&self) -> bool {
147        self.leader_provider.is_some()
148    }
149
150    async fn with_retry<T, F, R, H>(&self, task: &str, body_fn: F, get_header: H) -> Result<T>
151    where
152        R: Future<Output = std::result::Result<T, Status>>,
153        F: Fn(ProcedureServiceClient<Channel>) -> R,
154        H: Fn(&T) -> &Option<ResponseHeader>,
155    {
156        let Some(leader_provider) = self.leader_provider.as_ref() else {
157            return error::IllegalGrpcClientStateSnafu {
158                err_msg: "not started",
159            }
160            .fail();
161        };
162
163        let mut times = 0;
164        let mut last_error = None;
165
166        while times < self.max_retry {
167            if let Some(leader) = &leader_provider.leader() {
168                let client = self.make_client(leader)?;
169                match body_fn(client).await {
170                    Ok(res) => {
171                        if util::is_not_leader(get_header(&res)) {
172                            last_error = Some(format!("{leader} is not a leader"));
173                            warn!("Failed to {task} to {leader}, not a leader");
174                            let leader = leader_provider.ask_leader().await?;
175                            info!("DDL client updated to new leader addr: {leader}");
176                            times += 1;
177                            continue;
178                        }
179                        return Ok(res);
180                    }
181                    Err(status) => {
182                        // The leader may be unreachable.
183                        if util::is_unreachable(&status) {
184                            last_error = Some(status.to_string());
185                            warn!("Failed to {task} to {leader}, source: {status}");
186                            let leader = leader_provider.ask_leader().await?;
187                            info!("Procedure client updated to new leader addr: {leader}");
188                            times += 1;
189                            continue;
190                        } else {
191                            error!("An error occurred in gRPC, status: {status:?}");
192                            return Err(error::Error::from(status));
193                        }
194                    }
195                }
196            } else if let Err(err) = leader_provider.ask_leader().await {
197                return Err(err);
198            }
199        }
200
201        error::RetryTimesExceededSnafu {
202            msg: format!("Failed to {task}, last error: {:?}", last_error),
203            times: self.max_retry,
204        }
205        .fail()
206    }
207
208    async fn migrate_region(
209        &self,
210        region_id: u64,
211        from_peer: u64,
212        to_peer: u64,
213        timeout: Duration,
214    ) -> Result<MigrateRegionResponse> {
215        let mut req = MigrateRegionRequest {
216            region_id,
217            from_peer,
218            to_peer,
219            timeout_secs: timeout.as_secs() as u32,
220            ..Default::default()
221        };
222
223        req.set_header(
224            self.id,
225            self.role,
226            TracingContext::from_current_span().to_w3c(),
227        );
228
229        self.with_retry(
230            "migrate region",
231            move |mut client| {
232                let mut req = Request::new(req.clone());
233                req.set_timeout(self.timeout);
234
235                async move { client.migrate(req).await.map(|res| res.into_inner()) }
236            },
237            |resp: &MigrateRegionResponse| &resp.header,
238        )
239        .await
240    }
241
242    async fn reconcile(&self, request: ReconcileRequest) -> Result<ReconcileResponse> {
243        let mut req = request;
244        req.set_header(
245            self.id,
246            self.role,
247            TracingContext::from_current_span().to_w3c(),
248        );
249
250        self.with_retry(
251            "reconcile",
252            move |mut client| {
253                let mut req = Request::new(req.clone());
254                req.set_timeout(self.timeout);
255
256                async move { client.reconcile(req).await.map(|res| res.into_inner()) }
257            },
258            |resp: &ReconcileResponse| &resp.header,
259        )
260        .await
261    }
262
263    async fn query_procedure_state(&self, pid: &str) -> Result<ProcedureStateResponse> {
264        let mut req = QueryProcedureRequest {
265            pid: Some(ProcedureId { key: pid.into() }),
266            ..Default::default()
267        };
268
269        req.set_header(
270            self.id,
271            self.role,
272            TracingContext::from_current_span().to_w3c(),
273        );
274
275        self.with_retry(
276            "query procedure state",
277            move |mut client| {
278                let mut req = Request::new(req.clone());
279                req.set_timeout(self.timeout);
280
281                async move { client.query(req).await.map(|res| res.into_inner()) }
282            },
283            |resp: &ProcedureStateResponse| &resp.header,
284        )
285        .await
286    }
287
288    async fn submit_ddl_task(&self, mut req: DdlTaskRequest) -> Result<DdlTaskResponse> {
289        req.set_header(
290            self.id,
291            self.role,
292            TracingContext::from_current_span().to_w3c(),
293        );
294        let timeout = Duration::from_secs(req.timeout_secs.into());
295
296        self.with_retry(
297            "submit ddl task",
298            move |mut client| {
299                let mut req = Request::new(req.clone());
300                req.set_timeout(timeout);
301                async move { client.ddl(req).await.map(|res| res.into_inner()) }
302            },
303            |resp: &DdlTaskResponse| &resp.header,
304        )
305        .await
306    }
307
308    async fn list_procedures(&self) -> Result<ProcedureDetailResponse> {
309        let mut req = ProcedureDetailRequest::default();
310        req.set_header(
311            self.id,
312            self.role,
313            TracingContext::from_current_span().to_w3c(),
314        );
315
316        self.with_retry(
317            "list procedure",
318            move |mut client| {
319                let mut req = Request::new(req.clone());
320                req.set_timeout(self.timeout);
321                async move { client.details(req).await.map(|res| res.into_inner()) }
322            },
323            |resp: &ProcedureDetailResponse| &resp.header,
324        )
325        .await
326    }
327}
328
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use api::v1::meta::heartbeat_server::{Heartbeat, HeartbeatServer};
    use api::v1::meta::procedure_service_server::{ProcedureService, ProcedureServiceServer};
    use api::v1::meta::{
        AskLeaderRequest, AskLeaderResponse, DdlTaskRequest, DdlTaskResponse, HeartbeatRequest,
        HeartbeatResponse, MigrateRegionRequest, MigrateRegionResponse, Peer,
        ProcedureDetailRequest, ProcedureDetailResponse, ProcedureStateResponse,
        QueryProcedureRequest, ReconcileRequest, ReconcileResponse, ResponseHeader, Role,
    };
    use async_trait::async_trait;
    use common_error::status_code::StatusCode;
    use common_meta::rpc::ddl::{CommentObjectType, CommentOnTask, DdlTask, SubmitDdlTaskRequest};
    use common_telemetry::common_error::ext::ErrorExt;
    use common_telemetry::info;
    use session::context::QueryContext;
    use tokio::net::TcpListener;
    use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream};
    use tonic::codec::CompressionEncoding;
    use tonic::{Request, Response, Status};

    use crate::client::MetaClientBuilder;

    /// Header attached to every successful mock response: no error set.
    fn ok_header() -> Option<ResponseHeader> {
        Some(ResponseHeader {
            protocol_version: 0,
            error: None,
        })
    }

    /// Heartbeat service that only answers `ask_leader`, always with `leader_addr`.
    #[derive(Clone)]
    struct MockHeartbeat {
        leader_addr: String,
    }

    #[async_trait]
    impl Heartbeat for MockHeartbeat {
        type HeartbeatStream = ReceiverStream<Result<HeartbeatResponse, Status>>;

        async fn heartbeat(
            &self,
            _request: Request<tonic::Streaming<HeartbeatRequest>>,
        ) -> Result<Response<Self::HeartbeatStream>, Status> {
            let status = Status::unimplemented("heartbeat stream is not used in this test");
            Err(status)
        }

        async fn ask_leader(
            &self,
            _request: Request<AskLeaderRequest>,
        ) -> Result<Response<AskLeaderResponse>, Status> {
            let leader = Peer {
                id: 1,
                addr: self.leader_addr.clone(),
            };
            let reply = AskLeaderResponse {
                header: ok_header(),
                leader: Some(leader),
            };
            Ok(Response::new(reply))
        }
    }

    /// Procedure service whose `ddl` handler sleeps for `delay` before replying;
    /// every other RPC is unimplemented.
    #[derive(Clone)]
    struct MockProcedure {
        delay: Duration,
    }

    #[async_trait]
    impl ProcedureService for MockProcedure {
        async fn query(
            &self,
            _request: Request<QueryProcedureRequest>,
        ) -> Result<Response<ProcedureStateResponse>, Status> {
            let status = Status::unimplemented("query is not used in this test");
            Err(status)
        }

        async fn ddl(
            &self,
            _request: Request<DdlTaskRequest>,
        ) -> Result<Response<DdlTaskResponse>, Status> {
            // Hold the response back so the client-side timeout can fire first.
            tokio::time::sleep(self.delay).await;
            let reply = DdlTaskResponse {
                header: ok_header(),
                ..Default::default()
            };
            Ok(Response::new(reply))
        }

        async fn reconcile(
            &self,
            _request: Request<ReconcileRequest>,
        ) -> Result<Response<ReconcileResponse>, Status> {
            let status = Status::unimplemented("reconcile is not used in this test");
            Err(status)
        }

        async fn migrate(
            &self,
            _request: Request<MigrateRegionRequest>,
        ) -> Result<Response<MigrateRegionResponse>, Status> {
            let status = Status::unimplemented("migrate is not used in this test");
            Err(status)
        }

        async fn details(
            &self,
            _request: Request<ProcedureDetailRequest>,
        ) -> Result<Response<ProcedureDetailResponse>, Status> {
            let status = Status::unimplemented("details is not used in this test");
            Err(status)
        }
    }

    /// A DDL request carrying a 1s timeout must be cancelled by the client even though
    /// the mock server takes 4s to answer.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_meta_client_ddl_request_timeout() {
        common_telemetry::init_default_ut_logging();

        // Port 0 lets the OS pick a free port; the bound address doubles as the leader.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let leader_addr = listener.local_addr().unwrap().to_string();

        let heartbeat_svc = HeartbeatServer::new(MockHeartbeat {
            leader_addr: leader_addr.clone(),
        })
        .accept_compressed(CompressionEncoding::Zstd);
        let procedure_svc = ProcedureServiceServer::new(MockProcedure {
            delay: Duration::from_secs(4),
        })
        .accept_compressed(CompressionEncoding::Zstd);

        let server_handle = tokio::spawn(
            tonic::transport::Server::builder()
                .add_service(heartbeat_svc)
                .add_service(procedure_svc)
                .serve_with_incoming(TcpListenerStream::new(listener)),
        );

        let mut client = MetaClientBuilder::new(0, Role::Frontend)
            .enable_heartbeat()
            .enable_procedure()
            .build();
        client.start(&[leader_addr.as_str()]).await.unwrap();

        let task = DdlTask::new_comment_on(CommentOnTask {
            catalog_name: "greptime".to_string(),
            schema_name: "public".to_string(),
            object_type: CommentObjectType::Table,
            object_name: "test_table".to_string(),
            column_name: None,
            object_id: None,
            comment: Some("timeout".to_string()),
        });
        let mut request = SubmitDdlTaskRequest::new(QueryContext::arc(), task);
        request.timeout = Duration::from_secs(1);

        let started = Instant::now();
        let err = client.submit_ddl_task(request).await.unwrap_err();
        let elapsed = started.elapsed();

        // The 1s request timeout should fire long before the 4s server delay; allow up
        // to 2s in total.
        assert!(elapsed < Duration::from_secs(2));
        info!("err: {err:?}, code: {}", err.status_code());
        assert_eq!(err.status_code(), StatusCode::Cancelled);
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("Timeout expired"),
            "unexpected error: {err_msg}"
        );

        server_handle.abort();
    }
}
499}