1use std::collections::hash_map::Entry;
16use std::collections::HashMap;
17use std::fmt::{Debug, Formatter};
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::sync::{Arc, RwLock};
20
21use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo};
22use common_base::cancellation::CancellationHandle;
23use common_frontend::selector::{FrontendSelector, MetaClientSelector};
24use common_telemetry::{debug, info, warn};
25use common_time::util::current_time_millis;
26use meta_client::MetaClientRef;
27use snafu::{ensure, OptionExt, ResultExt};
28
29use crate::error;
30use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT};
31
32pub type ProcessId = u32;
33pub type ProcessManagerRef = Arc<ProcessManager>;
34
35pub struct ProcessManager {
37 server_addr: String,
39 next_id: AtomicU32,
41 catalogs: RwLock<HashMap<String, HashMap<ProcessId, CancellableProcess>>>,
43 frontend_selector: Option<MetaClientSelector>,
45}
46
47impl ProcessManager {
48 pub fn new(server_addr: String, meta_client: Option<MetaClientRef>) -> Self {
50 let frontend_selector = meta_client.map(MetaClientSelector::new);
51 Self {
52 server_addr,
53 next_id: Default::default(),
54 catalogs: Default::default(),
55 frontend_selector,
56 }
57 }
58}
59
60impl ProcessManager {
61 #[must_use]
63 pub fn register_query(
64 self: &Arc<Self>,
65 catalog: String,
66 schemas: Vec<String>,
67 query: String,
68 client: String,
69 query_id: Option<ProcessId>,
70 ) -> Ticket {
71 let id = query_id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
72 let process = ProcessInfo {
73 id,
74 catalog: catalog.clone(),
75 schemas,
76 query,
77 start_timestamp: current_time_millis(),
78 client,
79 frontend: self.server_addr.clone(),
80 };
81 let cancellation_handle = Arc::new(CancellationHandle::default());
82 let cancellable_process = CancellableProcess::new(cancellation_handle.clone(), process);
83
84 self.catalogs
85 .write()
86 .unwrap()
87 .entry(catalog.clone())
88 .or_default()
89 .insert(id, cancellable_process);
90
91 Ticket {
92 catalog,
93 manager: self.clone(),
94 id,
95 cancellation_handle,
96 }
97 }
98
99 pub fn next_id(&self) -> u32 {
101 self.next_id.fetch_add(1, Ordering::Relaxed)
102 }
103
104 pub fn deregister_query(&self, catalog: String, id: ProcessId) {
106 if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) {
107 let process = o.get_mut().remove(&id);
108 debug!("Deregister process: {:?}", process);
109 if o.get().is_empty() {
110 o.remove();
111 }
112 }
113 }
114
115 pub fn local_processes(&self, catalog: Option<&str>) -> error::Result<Vec<ProcessInfo>> {
117 let catalogs = self.catalogs.read().unwrap();
118 let result = if let Some(catalog) = catalog {
119 if let Some(catalogs) = catalogs.get(catalog) {
120 catalogs.values().map(|p| p.process.clone()).collect()
121 } else {
122 vec![]
123 }
124 } else {
125 catalogs
126 .values()
127 .flat_map(|v| v.values().map(|p| p.process.clone()))
128 .collect()
129 };
130 Ok(result)
131 }
132
133 pub async fn list_all_processes(
134 &self,
135 catalog: Option<&str>,
136 ) -> error::Result<Vec<ProcessInfo>> {
137 let mut processes = vec![];
138 if let Some(remote_frontend_selector) = self.frontend_selector.as_ref() {
139 let frontends = remote_frontend_selector
140 .select(|node| node.peer.addr != self.server_addr)
141 .await
142 .context(error::InvokeFrontendSnafu)?;
143 for mut f in frontends {
144 let result = f
145 .list_process(ListProcessRequest {
146 catalog: catalog.unwrap_or_default().to_string(),
147 })
148 .await
149 .context(error::InvokeFrontendSnafu);
150 match result {
151 Ok(resp) => {
152 processes.extend(resp.processes);
153 }
154 Err(e) => {
155 warn!(e; "Skipping failing node: {:?}", f)
156 }
157 }
158 }
159 }
160 processes.extend(self.local_processes(catalog)?);
161 Ok(processes)
162 }
163
164 pub async fn kill_process(
166 &self,
167 server_addr: String,
168 catalog: String,
169 id: ProcessId,
170 ) -> error::Result<bool> {
171 if server_addr == self.server_addr {
172 self.kill_local_process(catalog, id).await
173 } else {
174 let mut nodes = self
175 .frontend_selector
176 .as_ref()
177 .context(error::MetaClientMissingSnafu)?
178 .select(|node| node.peer.addr == server_addr)
179 .await
180 .context(error::InvokeFrontendSnafu)?;
181 ensure!(
182 !nodes.is_empty(),
183 error::FrontendNotFoundSnafu { addr: server_addr }
184 );
185
186 let request = KillProcessRequest {
187 server_addr,
188 catalog,
189 process_id: id,
190 };
191 nodes[0]
192 .kill_process(request)
193 .await
194 .context(error::InvokeFrontendSnafu)?;
195 Ok(true)
196 }
197 }
198
199 pub async fn kill_local_process(&self, catalog: String, id: ProcessId) -> error::Result<bool> {
201 if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
202 if let Some(process) = catalogs.remove(&id) {
203 process.handle.cancel();
204 info!(
205 "Killed process, catalog: {}, id: {:?}",
206 process.process.catalog, process.process.id
207 );
208 PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
209 Ok(true)
210 } else {
211 debug!("Failed to kill process, id not found: {}", id);
212 Ok(false)
213 }
214 } else {
215 debug!("Failed to kill process, catalog not found: {}", catalog);
216 Ok(false)
217 }
218 }
219}
220
221pub struct Ticket {
222 pub(crate) catalog: String,
223 pub(crate) manager: ProcessManagerRef,
224 pub(crate) id: ProcessId,
225 pub cancellation_handle: Arc<CancellationHandle>,
226}
227
228impl Drop for Ticket {
229 fn drop(&mut self) {
230 self.manager
231 .deregister_query(std::mem::take(&mut self.catalog), self.id);
232 }
233}
234
235struct CancellableProcess {
236 handle: Arc<CancellationHandle>,
237 process: ProcessInfo,
238}
239
240impl Drop for CancellableProcess {
241 fn drop(&mut self) {
242 PROCESS_LIST_COUNT
243 .with_label_values(&[&self.process.catalog])
244 .dec();
245 }
246}
247
248impl CancellableProcess {
249 fn new(handle: Arc<CancellationHandle>, process: ProcessInfo) -> Self {
250 PROCESS_LIST_COUNT
251 .with_label_values(&[&process.catalog])
252 .inc();
253 Self { handle, process }
254 }
255}
256
257impl Debug for CancellableProcess {
258 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
259 f.debug_struct("CancellableProcess")
260 .field("cancelled", &self.handle.is_cancelled())
261 .field("process", &self.process)
262 .finish()
263 }
264}
265
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::process_manager::{ProcessId, ProcessManager, ProcessManagerRef, Ticket};

    /// Address every test manager is created with.
    const SERVER_ADDR: &str = "127.0.0.1:8000";

    /// Builds a manager without a meta client, i.e. one that only knows local processes.
    fn new_manager() -> ProcessManagerRef {
        Arc::new(ProcessManager::new(SERVER_ADDR.to_string(), None))
    }

    /// Registers a query on `manager`, converting the borrowed arguments to owned values.
    fn register(
        manager: &ProcessManagerRef,
        catalog: &str,
        schemas: &[&str],
        query: &str,
        client: &str,
        query_id: Option<ProcessId>,
    ) -> Ticket {
        manager.register_query(
            catalog.to_string(),
            schemas.iter().map(|schema| schema.to_string()).collect(),
            query.to_string(),
            client.to_string(),
            query_id,
        )
    }

    #[tokio::test]
    async fn test_register_query() {
        let process_manager = new_manager();
        let ticket = register(
            &process_manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "",
            None,
        );

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(&running_processes[0].frontend, "127.0.0.1:8000");
        assert_eq!(running_processes[0].id, ticket.id);
        assert_eq!(&running_processes[0].query, "SELECT * FROM table");

        drop(ticket);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_register_query_with_custom_id() {
        let process_manager = new_manager();
        let custom_id = 12345;

        let ticket = register(
            &process_manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            Some(custom_id),
        );

        assert_eq!(ticket.id, custom_id);

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(running_processes[0].id, custom_id);
        assert_eq!(&running_processes[0].client, "client1");
    }

    #[tokio::test]
    async fn test_multiple_queries_same_catalog() {
        let process_manager = new_manager();

        let ticket1 = register(
            &process_manager,
            "public",
            &["schema1"],
            "SELECT * FROM table1",
            "client1",
            None,
        );
        let ticket2 = register(
            &process_manager,
            "public",
            &["schema2"],
            "SELECT * FROM table2",
            "client2",
            None,
        );

        let running_processes = process_manager.local_processes(Some("public")).unwrap();
        assert_eq!(running_processes.len(), 2);

        let ids: Vec<u32> = running_processes.iter().map(|p| p.id).collect();
        assert!(ids.contains(&ticket1.id));
        assert!(ids.contains(&ticket2.id));
    }

    #[tokio::test]
    async fn test_multiple_catalogs() {
        let process_manager = new_manager();

        let _ticket1 = register(
            &process_manager,
            "catalog1",
            &["schema1"],
            "SELECT * FROM table1",
            "client1",
            None,
        );
        let _ticket2 = register(
            &process_manager,
            "catalog2",
            &["schema2"],
            "SELECT * FROM table2",
            "client2",
            None,
        );

        let catalog1_processes = process_manager.local_processes(Some("catalog1")).unwrap();
        assert_eq!(catalog1_processes.len(), 1);
        assert_eq!(&catalog1_processes[0].catalog, "catalog1");

        let catalog2_processes = process_manager.local_processes(Some("catalog2")).unwrap();
        assert_eq!(catalog2_processes.len(), 1);
        assert_eq!(&catalog2_processes[0].catalog, "catalog2");

        let all_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(all_processes.len(), 2);
    }

    #[tokio::test]
    async fn test_deregister_query() {
        let process_manager = new_manager();

        let ticket = register(
            &process_manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            None,
        );
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 1);
        process_manager.deregister_query("public".to_string(), ticket.id);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_cancellation_handle() {
        let process_manager = new_manager();

        let ticket = register(
            &process_manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            None,
        );

        assert!(!ticket.cancellation_handle.is_cancelled());
        ticket.cancellation_handle.cancel();
        assert!(ticket.cancellation_handle.is_cancelled());
    }

    #[tokio::test]
    async fn test_kill_local_process() {
        let process_manager = new_manager();

        let ticket = register(
            &process_manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            None,
        );
        assert!(!ticket.cancellation_handle.is_cancelled());
        let killed = process_manager
            .kill_process(SERVER_ADDR.to_string(), "public".to_string(), ticket.id)
            .await
            .unwrap();

        assert!(killed);
        // Killing must trigger the handle the ticket shares with the manager.
        assert!(ticket.cancellation_handle.is_cancelled());
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_kill_nonexistent_process() {
        let process_manager = new_manager();
        let killed = process_manager
            .kill_process(SERVER_ADDR.to_string(), "public".to_string(), 999)
            .await
            .unwrap();
        assert!(!killed);
    }

    #[tokio::test]
    async fn test_kill_process_nonexistent_catalog() {
        let process_manager = new_manager();
        let killed = process_manager
            .kill_process(SERVER_ADDR.to_string(), "nonexistent".to_string(), 1)
            .await
            .unwrap();
        assert!(!killed);
    }

    #[tokio::test]
    async fn test_process_info_fields() {
        let process_manager = new_manager();

        let _ticket = register(
            &process_manager,
            "test_catalog",
            &["schema1", "schema2"],
            "SELECT COUNT(*) FROM users WHERE age > 18",
            "test_client",
            Some(42),
        );

        let processes = process_manager.local_processes(None).unwrap();
        assert_eq!(processes.len(), 1);

        let process = &processes[0];
        assert_eq!(process.id, 42);
        assert_eq!(&process.catalog, "test_catalog");
        assert_eq!(process.schemas, vec!["schema1", "schema2"]);
        assert_eq!(&process.query, "SELECT COUNT(*) FROM users WHERE age > 18");
        assert_eq!(&process.client, "test_client");
        assert_eq!(&process.frontend, "127.0.0.1:8000");
        assert!(process.start_timestamp > 0);
    }

    #[tokio::test]
    async fn test_ticket_drop_deregisters_process() {
        let process_manager = new_manager();

        {
            let _ticket = register(
                &process_manager,
                "public",
                &["test"],
                "SELECT * FROM table",
                "client1",
                None,
            );

            assert_eq!(process_manager.local_processes(None).unwrap().len(), 1);
        }

        // The ticket went out of scope above, which must have deregistered the process.
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }
}