1use std::future::Future;
16use std::sync::Arc;
17use std::time::Duration;
18
19use api::v1::meta::procedure_service_client::ProcedureServiceClient;
20use api::v1::meta::{
21 DdlTaskRequest, DdlTaskResponse, MigrateRegionRequest, MigrateRegionResponse,
22 ProcedureDetailRequest, ProcedureDetailResponse, ProcedureId, ProcedureStateResponse,
23 QueryProcedureRequest, ResponseHeader, Role,
24};
25use common_grpc::channel_manager::ChannelManager;
26use common_telemetry::tracing_context::TracingContext;
27use common_telemetry::{error, info, warn};
28use snafu::{ensure, ResultExt};
29use tokio::sync::RwLock;
30use tonic::codec::CompressionEncoding;
31use tonic::transport::Channel;
32use tonic::Status;
33
34use crate::client::ask_leader::AskLeader;
35use crate::client::{util, Id, LeaderProviderRef};
36use crate::error;
37use crate::error::Result;
38
39#[derive(Clone, Debug)]
40pub struct Client {
41 inner: Arc<RwLock<Inner>>,
42}
43
44impl Client {
45 pub fn new(id: Id, role: Role, channel_manager: ChannelManager, max_retry: usize) -> Self {
46 let inner = Arc::new(RwLock::new(Inner {
47 id,
48 role,
49 channel_manager,
50 leader_provider: None,
51 max_retry,
52 }));
53
54 Self { inner }
55 }
56
57 pub async fn start<U, A>(&mut self, urls: A) -> Result<()>
58 where
59 U: AsRef<str>,
60 A: AsRef<[U]>,
61 {
62 let mut inner = self.inner.write().await;
63 inner.start(urls)
64 }
65
66 pub(crate) async fn start_with(&self, leader_provider: LeaderProviderRef) -> Result<()> {
68 let mut inner = self.inner.write().await;
69 inner.start_with(leader_provider)
70 }
71
72 pub async fn submit_ddl_task(&self, req: DdlTaskRequest) -> Result<DdlTaskResponse> {
73 let inner = self.inner.read().await;
74 inner.submit_ddl_task(req).await
75 }
76
77 pub async fn query_procedure_state(&self, pid: &str) -> Result<ProcedureStateResponse> {
79 let inner = self.inner.read().await;
80 inner.query_procedure_state(pid).await
81 }
82
83 pub async fn migrate_region(
89 &self,
90 region_id: u64,
91 from_peer: u64,
92 to_peer: u64,
93 timeout: Duration,
94 ) -> Result<MigrateRegionResponse> {
95 let inner = self.inner.read().await;
96 inner
97 .migrate_region(region_id, from_peer, to_peer, timeout)
98 .await
99 }
100
101 pub async fn list_procedures(&self) -> Result<ProcedureDetailResponse> {
102 let inner = self.inner.read().await;
103 inner.list_procedures().await
104 }
105}
106
107#[derive(Debug)]
108struct Inner {
109 id: Id,
110 role: Role,
111 channel_manager: ChannelManager,
112 leader_provider: Option<LeaderProviderRef>,
113 max_retry: usize,
114}
115
116impl Inner {
117 fn start_with(&mut self, leader_provider: LeaderProviderRef) -> Result<()> {
118 ensure!(
119 !self.is_started(),
120 error::IllegalGrpcClientStateSnafu {
121 err_msg: "DDL client already started",
122 }
123 );
124 self.leader_provider = Some(leader_provider);
125 Ok(())
126 }
127
128 fn start<U, A>(&mut self, urls: A) -> Result<()>
129 where
130 U: AsRef<str>,
131 A: AsRef<[U]>,
132 {
133 let peers = urls
134 .as_ref()
135 .iter()
136 .map(|url| url.as_ref().to_string())
137 .collect::<Vec<_>>();
138 let ask_leader = AskLeader::new(
139 self.id,
140 self.role,
141 peers,
142 self.channel_manager.clone(),
143 self.max_retry,
144 );
145 self.start_with(Arc::new(ask_leader))
146 }
147
148 fn make_client(&self, addr: impl AsRef<str>) -> Result<ProcedureServiceClient<Channel>> {
149 let channel = self
150 .channel_manager
151 .get(addr)
152 .context(error::CreateChannelSnafu)?;
153
154 Ok(ProcedureServiceClient::new(channel)
155 .accept_compressed(CompressionEncoding::Gzip)
156 .accept_compressed(CompressionEncoding::Zstd)
157 .send_compressed(CompressionEncoding::Zstd))
158 }
159
160 #[inline]
161 fn is_started(&self) -> bool {
162 self.leader_provider.is_some()
163 }
164
165 async fn with_retry<T, F, R, H>(&self, task: &str, body_fn: F, get_header: H) -> Result<T>
166 where
167 R: Future<Output = std::result::Result<T, Status>>,
168 F: Fn(ProcedureServiceClient<Channel>) -> R,
169 H: Fn(&T) -> &Option<ResponseHeader>,
170 {
171 let Some(leader_provider) = self.leader_provider.as_ref() else {
172 return error::IllegalGrpcClientStateSnafu {
173 err_msg: "not started",
174 }
175 .fail();
176 };
177
178 let mut times = 0;
179 let mut last_error = None;
180
181 while times < self.max_retry {
182 if let Some(leader) = &leader_provider.leader() {
183 let client = self.make_client(leader)?;
184 match body_fn(client).await {
185 Ok(res) => {
186 if util::is_not_leader(get_header(&res)) {
187 last_error = Some(format!("{leader} is not a leader"));
188 warn!("Failed to {task} to {leader}, not a leader");
189 let leader = leader_provider.ask_leader().await?;
190 info!("DDL client updated to new leader addr: {leader}");
191 times += 1;
192 continue;
193 }
194 return Ok(res);
195 }
196 Err(status) => {
197 if util::is_unreachable(&status) {
199 last_error = Some(status.to_string());
200 warn!("Failed to {task} to {leader}, source: {status}");
201 let leader = leader_provider.ask_leader().await?;
202 info!("Procedure client updated to new leader addr: {leader}");
203 times += 1;
204 continue;
205 } else {
206 error!("An error occurred in gRPC, status: {status}");
207 return Err(error::Error::from(status));
208 }
209 }
210 }
211 } else if let Err(err) = leader_provider.ask_leader().await {
212 return Err(err);
213 }
214 }
215
216 error::RetryTimesExceededSnafu {
217 msg: format!("Failed to {task}, last error: {:?}", last_error),
218 times: self.max_retry,
219 }
220 .fail()
221 }
222
223 async fn migrate_region(
224 &self,
225 region_id: u64,
226 from_peer: u64,
227 to_peer: u64,
228 timeout: Duration,
229 ) -> Result<MigrateRegionResponse> {
230 let mut req = MigrateRegionRequest {
231 region_id,
232 from_peer,
233 to_peer,
234 timeout_secs: timeout.as_secs() as u32,
235 ..Default::default()
236 };
237
238 req.set_header(
239 self.id,
240 self.role,
241 TracingContext::from_current_span().to_w3c(),
242 );
243
244 self.with_retry(
245 "migrate region",
246 move |mut client| {
247 let req = req.clone();
248
249 async move { client.migrate(req).await.map(|res| res.into_inner()) }
250 },
251 |resp: &MigrateRegionResponse| &resp.header,
252 )
253 .await
254 }
255
256 async fn query_procedure_state(&self, pid: &str) -> Result<ProcedureStateResponse> {
257 let mut req = QueryProcedureRequest {
258 pid: Some(ProcedureId { key: pid.into() }),
259 ..Default::default()
260 };
261
262 req.set_header(
263 self.id,
264 self.role,
265 TracingContext::from_current_span().to_w3c(),
266 );
267
268 self.with_retry(
269 "query procedure state",
270 move |mut client| {
271 let req = req.clone();
272
273 async move { client.query(req).await.map(|res| res.into_inner()) }
274 },
275 |resp: &ProcedureStateResponse| &resp.header,
276 )
277 .await
278 }
279
280 async fn submit_ddl_task(&self, mut req: DdlTaskRequest) -> Result<DdlTaskResponse> {
281 req.set_header(
282 self.id,
283 self.role,
284 TracingContext::from_current_span().to_w3c(),
285 );
286
287 self.with_retry(
288 "submit ddl task",
289 move |mut client| {
290 let req = req.clone();
291 async move { client.ddl(req).await.map(|res| res.into_inner()) }
292 },
293 |resp: &DdlTaskResponse| &resp.header,
294 )
295 .await
296 }
297
298 async fn list_procedures(&self) -> Result<ProcedureDetailResponse> {
299 let mut req = ProcedureDetailRequest::default();
300 req.set_header(
301 self.id,
302 self.role,
303 TracingContext::from_current_span().to_w3c(),
304 );
305
306 self.with_retry(
307 "list procedure",
308 move |mut client| {
309 let req = req.clone();
310 async move { client.details(req).await.map(|res| res.into_inner()) }
311 },
312 |resp: &ProcedureDetailResponse| &resp.header,
313 )
314 .await
315 }
316}