1use std::collections::hash_map::Entry;
16use std::collections::HashMap;
17use std::fmt::{Debug, Formatter};
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::sync::{Arc, RwLock};
20
21use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo};
22use common_base::cancellation::CancellationHandle;
23use common_frontend::selector::{FrontendSelector, MetaClientSelector};
24use common_telemetry::{debug, info};
25use common_time::util::current_time_millis;
26use meta_client::MetaClientRef;
27use snafu::{ensure, OptionExt, ResultExt};
28
29use crate::error;
30use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT};
31
/// Numeric identifier of a registered query process.
pub type ProcessId = u32;
/// Shared, reference-counted [`ProcessManager`]; every [`Ticket`] holds one so that it can
/// deregister its process when dropped.
pub type ProcessManagerRef = Arc<ProcessManager>;
34
35pub struct ProcessManager {
37 server_addr: String,
39 next_id: AtomicU32,
41 catalogs: RwLock<HashMap<String, HashMap<ProcessId, CancellableProcess>>>,
43 frontend_selector: Option<MetaClientSelector>,
45}
46
47impl ProcessManager {
48 pub fn new(server_addr: String, meta_client: Option<MetaClientRef>) -> Self {
50 let frontend_selector = meta_client.map(MetaClientSelector::new);
51 Self {
52 server_addr,
53 next_id: Default::default(),
54 catalogs: Default::default(),
55 frontend_selector,
56 }
57 }
58}
59
60impl ProcessManager {
61 #[must_use]
63 pub fn register_query(
64 self: &Arc<Self>,
65 catalog: String,
66 schemas: Vec<String>,
67 query: String,
68 client: String,
69 query_id: Option<ProcessId>,
70 ) -> Ticket {
71 let id = query_id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
72 let process = ProcessInfo {
73 id,
74 catalog: catalog.clone(),
75 schemas,
76 query,
77 start_timestamp: current_time_millis(),
78 client,
79 frontend: self.server_addr.clone(),
80 };
81 let cancellation_handle = Arc::new(CancellationHandle::default());
82 let cancellable_process = CancellableProcess::new(cancellation_handle.clone(), process);
83
84 self.catalogs
85 .write()
86 .unwrap()
87 .entry(catalog.clone())
88 .or_default()
89 .insert(id, cancellable_process);
90
91 Ticket {
92 catalog,
93 manager: self.clone(),
94 id,
95 cancellation_handle,
96 }
97 }
98
99 pub fn next_id(&self) -> u32 {
101 self.next_id.fetch_add(1, Ordering::Relaxed)
102 }
103
104 pub fn deregister_query(&self, catalog: String, id: ProcessId) {
106 if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) {
107 let process = o.get_mut().remove(&id);
108 debug!("Deregister process: {:?}", process);
109 if o.get().is_empty() {
110 o.remove();
111 }
112 }
113 }
114
115 pub fn local_processes(&self, catalog: Option<&str>) -> error::Result<Vec<ProcessInfo>> {
117 let catalogs = self.catalogs.read().unwrap();
118 let result = if let Some(catalog) = catalog {
119 if let Some(catalogs) = catalogs.get(catalog) {
120 catalogs.values().map(|p| p.process.clone()).collect()
121 } else {
122 vec![]
123 }
124 } else {
125 catalogs
126 .values()
127 .flat_map(|v| v.values().map(|p| p.process.clone()))
128 .collect()
129 };
130 Ok(result)
131 }
132
133 pub async fn list_all_processes(
134 &self,
135 catalog: Option<&str>,
136 ) -> error::Result<Vec<ProcessInfo>> {
137 let mut processes = vec![];
138 if let Some(remote_frontend_selector) = self.frontend_selector.as_ref() {
139 let frontends = remote_frontend_selector
140 .select(|node| node.peer.addr != self.server_addr)
141 .await
142 .context(error::InvokeFrontendSnafu)?;
143 for mut f in frontends {
144 processes.extend(
145 f.list_process(ListProcessRequest {
146 catalog: catalog.unwrap_or_default().to_string(),
147 })
148 .await
149 .context(error::InvokeFrontendSnafu)?
150 .processes,
151 );
152 }
153 }
154 processes.extend(self.local_processes(catalog)?);
155 Ok(processes)
156 }
157
158 pub async fn kill_process(
160 &self,
161 server_addr: String,
162 catalog: String,
163 id: ProcessId,
164 ) -> error::Result<bool> {
165 if server_addr == self.server_addr {
166 self.kill_local_process(catalog, id).await
167 } else {
168 let mut nodes = self
169 .frontend_selector
170 .as_ref()
171 .context(error::MetaClientMissingSnafu)?
172 .select(|node| node.peer.addr == server_addr)
173 .await
174 .context(error::InvokeFrontendSnafu)?;
175 ensure!(
176 !nodes.is_empty(),
177 error::FrontendNotFoundSnafu { addr: server_addr }
178 );
179
180 let request = KillProcessRequest {
181 server_addr,
182 catalog,
183 process_id: id,
184 };
185 nodes[0]
186 .kill_process(request)
187 .await
188 .context(error::InvokeFrontendSnafu)?;
189 Ok(true)
190 }
191 }
192
193 pub async fn kill_local_process(&self, catalog: String, id: ProcessId) -> error::Result<bool> {
195 if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
196 if let Some(process) = catalogs.remove(&id) {
197 process.handle.cancel();
198 info!(
199 "Killed process, catalog: {}, id: {:?}",
200 process.process.catalog, process.process.id
201 );
202 PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
203 Ok(true)
204 } else {
205 debug!("Failed to kill process, id not found: {}", id);
206 Ok(false)
207 }
208 } else {
209 debug!("Failed to kill process, catalog not found: {}", catalog);
210 Ok(false)
211 }
212 }
213}
214
215pub struct Ticket {
216 pub(crate) catalog: String,
217 pub(crate) manager: ProcessManagerRef,
218 pub(crate) id: ProcessId,
219 pub cancellation_handle: Arc<CancellationHandle>,
220}
221
222impl Drop for Ticket {
223 fn drop(&mut self) {
224 self.manager
225 .deregister_query(std::mem::take(&mut self.catalog), self.id);
226 }
227}
228
229struct CancellableProcess {
230 handle: Arc<CancellationHandle>,
231 process: ProcessInfo,
232}
233
234impl Drop for CancellableProcess {
235 fn drop(&mut self) {
236 PROCESS_LIST_COUNT
237 .with_label_values(&[&self.process.catalog])
238 .dec();
239 }
240}
241
242impl CancellableProcess {
243 fn new(handle: Arc<CancellationHandle>, process: ProcessInfo) -> Self {
244 PROCESS_LIST_COUNT
245 .with_label_values(&[&process.catalog])
246 .inc();
247 Self { handle, process }
248 }
249}
250
251impl Debug for CancellableProcess {
252 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
253 f.debug_struct("CancellableProcess")
254 .field("cancelled", &self.handle.is_cancelled())
255 .field("process", &self.process)
256 .finish()
257 }
258}
259
#[cfg(test)]
mod tests {
    //! Unit tests for `ProcessManager`, using a manager with no meta client so every
    //! operation stays local.

    use std::sync::Arc;

    use crate::process_manager::ProcessManager;

    /// Registering a query makes it visible locally; dropping the ticket removes it.
    #[tokio::test]
    async fn test_register_query() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));
        let ticket = process_manager.clone().register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "".to_string(),
            None,
        );

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(&running_processes[0].frontend, "127.0.0.1:8000");
        assert_eq!(running_processes[0].id, ticket.id);
        assert_eq!(&running_processes[0].query, "SELECT * FROM table");

        drop(ticket);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    /// An explicitly supplied query id is used verbatim instead of an allocated one.
    #[tokio::test]
    async fn test_register_query_with_custom_id() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));
        let custom_id = 12345;

        let ticket = process_manager.clone().register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            Some(custom_id),
        );

        assert_eq!(ticket.id, custom_id);

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(running_processes[0].id, custom_id);
        assert_eq!(&running_processes[0].client, "client1");
    }

    /// Two queries in one catalog are both listed under that catalog.
    #[tokio::test]
    async fn test_multiple_queries_same_catalog() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let ticket1 = process_manager.clone().register_query(
            "public".to_string(),
            vec!["schema1".to_string()],
            "SELECT * FROM table1".to_string(),
            "client1".to_string(),
            None,
        );

        let ticket2 = process_manager.clone().register_query(
            "public".to_string(),
            vec!["schema2".to_string()],
            "SELECT * FROM table2".to_string(),
            "client2".to_string(),
            None,
        );

        let running_processes = process_manager.local_processes(Some("public")).unwrap();
        assert_eq!(running_processes.len(), 2);

        // Listing order is not guaranteed, so compare by membership.
        let ids: Vec<u32> = running_processes.iter().map(|p| p.id).collect();
        assert!(ids.contains(&ticket1.id));
        assert!(ids.contains(&ticket2.id));
    }

    /// Filtering by catalog returns only that catalog's processes; `None` returns all.
    #[tokio::test]
    async fn test_multiple_catalogs() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let _ticket1 = process_manager.clone().register_query(
            "catalog1".to_string(),
            vec!["schema1".to_string()],
            "SELECT * FROM table1".to_string(),
            "client1".to_string(),
            None,
        );

        let _ticket2 = process_manager.clone().register_query(
            "catalog2".to_string(),
            vec!["schema2".to_string()],
            "SELECT * FROM table2".to_string(),
            "client2".to_string(),
            None,
        );

        let catalog1_processes = process_manager.local_processes(Some("catalog1")).unwrap();
        assert_eq!(catalog1_processes.len(), 1);
        assert_eq!(&catalog1_processes[0].catalog, "catalog1");

        let catalog2_processes = process_manager.local_processes(Some("catalog2")).unwrap();
        assert_eq!(catalog2_processes.len(), 1);
        assert_eq!(&catalog2_processes[0].catalog, "catalog2");

        let all_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(all_processes.len(), 2);
    }

    /// Explicit deregistration removes the process while its ticket is still alive.
    #[tokio::test]
    async fn test_deregister_query() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let ticket = process_manager.clone().register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
        );
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 1);
        process_manager.deregister_query("public".to_string(), ticket.id);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    /// The ticket's cancellation handle starts uncancelled and reflects `cancel()`.
    #[tokio::test]
    async fn test_cancellation_handle() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let ticket = process_manager.clone().register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
        );

        assert!(!ticket.cancellation_handle.is_cancelled());
        ticket.cancellation_handle.cancel();
        assert!(ticket.cancellation_handle.is_cancelled());
    }

    /// Killing via this server's own address takes the local path and removes the process.
    #[tokio::test]
    async fn test_kill_local_process() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let ticket = process_manager.clone().register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
        );
        assert!(!ticket.cancellation_handle.is_cancelled());
        let killed = process_manager
            .kill_process(
                "127.0.0.1:8000".to_string(),
                "public".to_string(),
                ticket.id,
            )
            .await
            .unwrap();

        assert!(killed);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    /// Killing an unknown id reports `false` rather than an error.
    #[tokio::test]
    async fn test_kill_nonexistent_process() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));
        let killed = process_manager
            .kill_process("127.0.0.1:8000".to_string(), "public".to_string(), 999)
            .await
            .unwrap();
        assert!(!killed);
    }

    /// Killing in an unknown catalog reports `false` rather than an error.
    #[tokio::test]
    async fn test_kill_process_nonexistent_catalog() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));
        let killed = process_manager
            .kill_process("127.0.0.1:8000".to_string(), "nonexistent".to_string(), 1)
            .await
            .unwrap();
        assert!(!killed);
    }

    /// Every `ProcessInfo` field is populated from the registration arguments.
    #[tokio::test]
    async fn test_process_info_fields() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let _ticket = process_manager.clone().register_query(
            "test_catalog".to_string(),
            vec!["schema1".to_string(), "schema2".to_string()],
            "SELECT COUNT(*) FROM users WHERE age > 18".to_string(),
            "test_client".to_string(),
            Some(42),
        );

        let processes = process_manager.local_processes(None).unwrap();
        assert_eq!(processes.len(), 1);

        let process = &processes[0];
        assert_eq!(process.id, 42);
        assert_eq!(&process.catalog, "test_catalog");
        assert_eq!(process.schemas, vec!["schema1", "schema2"]);
        assert_eq!(&process.query, "SELECT COUNT(*) FROM users WHERE age > 18");
        assert_eq!(&process.client, "test_client");
        assert_eq!(&process.frontend, "127.0.0.1:8000");
        assert!(process.start_timestamp > 0);
    }

    /// A ticket going out of scope deregisters its process automatically.
    #[tokio::test]
    async fn test_ticket_drop_deregisters_process() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        {
            let _ticket = process_manager.clone().register_query(
                "public".to_string(),
                vec!["test".to_string()],
                "SELECT * FROM table".to_string(),
                "client1".to_string(),
                None,
            );

            assert_eq!(process_manager.local_processes(None).unwrap().len(), 1);
        }
        // `_ticket` has been dropped at the end of the scope above.
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }
}