catalog/
process_manager.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::hash_map::Entry;
16use std::collections::HashMap;
17use std::fmt::{Debug, Formatter};
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::sync::{Arc, RwLock};
20
21use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo};
22use common_base::cancellation::CancellationHandle;
23use common_frontend::selector::{FrontendSelector, MetaClientSelector};
24use common_telemetry::{debug, info};
25use common_time::util::current_time_millis;
26use meta_client::MetaClientRef;
27use snafu::{ensure, OptionExt, ResultExt};
28
29use crate::error;
30use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT};
31
32pub type ProcessId = u32;
33pub type ProcessManagerRef = Arc<ProcessManager>;
34
35/// Query process manager.
36pub struct ProcessManager {
37    /// Local frontend server address,
38    server_addr: String,
39    /// Next process id for local queries.
40    next_id: AtomicU32,
41    /// Running process per catalog.
42    catalogs: RwLock<HashMap<String, HashMap<ProcessId, CancellableProcess>>>,
43    /// Frontend selector to locate frontend nodes.
44    frontend_selector: Option<MetaClientSelector>,
45}
46
47impl ProcessManager {
48    /// Create a [ProcessManager] instance with server address and kv client.
49    pub fn new(server_addr: String, meta_client: Option<MetaClientRef>) -> Self {
50        let frontend_selector = meta_client.map(MetaClientSelector::new);
51        Self {
52            server_addr,
53            next_id: Default::default(),
54            catalogs: Default::default(),
55            frontend_selector,
56        }
57    }
58}
59
60impl ProcessManager {
61    /// Registers a submitted query. Use the provided id if present.
62    #[must_use]
63    pub fn register_query(
64        self: &Arc<Self>,
65        catalog: String,
66        schemas: Vec<String>,
67        query: String,
68        client: String,
69        query_id: Option<ProcessId>,
70    ) -> Ticket {
71        let id = query_id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
72        let process = ProcessInfo {
73            id,
74            catalog: catalog.clone(),
75            schemas,
76            query,
77            start_timestamp: current_time_millis(),
78            client,
79            frontend: self.server_addr.clone(),
80        };
81        let cancellation_handle = Arc::new(CancellationHandle::default());
82        let cancellable_process = CancellableProcess::new(cancellation_handle.clone(), process);
83
84        self.catalogs
85            .write()
86            .unwrap()
87            .entry(catalog.clone())
88            .or_default()
89            .insert(id, cancellable_process);
90
91        Ticket {
92            catalog,
93            manager: self.clone(),
94            id,
95            cancellation_handle,
96        }
97    }
98
99    /// Generates the next process id.
100    pub fn next_id(&self) -> u32 {
101        self.next_id.fetch_add(1, Ordering::Relaxed)
102    }
103
104    /// De-register a query from process list.
105    pub fn deregister_query(&self, catalog: String, id: ProcessId) {
106        if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) {
107            let process = o.get_mut().remove(&id);
108            debug!("Deregister process: {:?}", process);
109            if o.get().is_empty() {
110                o.remove();
111            }
112        }
113    }
114
115    /// List local running processes in given catalog.
116    pub fn local_processes(&self, catalog: Option<&str>) -> error::Result<Vec<ProcessInfo>> {
117        let catalogs = self.catalogs.read().unwrap();
118        let result = if let Some(catalog) = catalog {
119            if let Some(catalogs) = catalogs.get(catalog) {
120                catalogs.values().map(|p| p.process.clone()).collect()
121            } else {
122                vec![]
123            }
124        } else {
125            catalogs
126                .values()
127                .flat_map(|v| v.values().map(|p| p.process.clone()))
128                .collect()
129        };
130        Ok(result)
131    }
132
133    pub async fn list_all_processes(
134        &self,
135        catalog: Option<&str>,
136    ) -> error::Result<Vec<ProcessInfo>> {
137        let mut processes = vec![];
138        if let Some(remote_frontend_selector) = self.frontend_selector.as_ref() {
139            let frontends = remote_frontend_selector
140                .select(|node| node.peer.addr != self.server_addr)
141                .await
142                .context(error::InvokeFrontendSnafu)?;
143            for mut f in frontends {
144                processes.extend(
145                    f.list_process(ListProcessRequest {
146                        catalog: catalog.unwrap_or_default().to_string(),
147                    })
148                    .await
149                    .context(error::InvokeFrontendSnafu)?
150                    .processes,
151                );
152            }
153        }
154        processes.extend(self.local_processes(catalog)?);
155        Ok(processes)
156    }
157
158    /// Kills query with provided catalog and id.
159    pub async fn kill_process(
160        &self,
161        server_addr: String,
162        catalog: String,
163        id: ProcessId,
164    ) -> error::Result<bool> {
165        if server_addr == self.server_addr {
166            self.kill_local_process(catalog, id).await
167        } else {
168            let mut nodes = self
169                .frontend_selector
170                .as_ref()
171                .context(error::MetaClientMissingSnafu)?
172                .select(|node| node.peer.addr == server_addr)
173                .await
174                .context(error::InvokeFrontendSnafu)?;
175            ensure!(
176                !nodes.is_empty(),
177                error::FrontendNotFoundSnafu { addr: server_addr }
178            );
179
180            let request = KillProcessRequest {
181                server_addr,
182                catalog,
183                process_id: id,
184            };
185            nodes[0]
186                .kill_process(request)
187                .await
188                .context(error::InvokeFrontendSnafu)?;
189            Ok(true)
190        }
191    }
192
193    /// Kills local query with provided catalog and id.
194    pub async fn kill_local_process(&self, catalog: String, id: ProcessId) -> error::Result<bool> {
195        if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
196            if let Some(process) = catalogs.remove(&id) {
197                process.handle.cancel();
198                info!(
199                    "Killed process, catalog: {}, id: {:?}",
200                    process.process.catalog, process.process.id
201                );
202                PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
203                Ok(true)
204            } else {
205                debug!("Failed to kill process, id not found: {}", id);
206                Ok(false)
207            }
208        } else {
209            debug!("Failed to kill process, catalog not found: {}", catalog);
210            Ok(false)
211        }
212    }
213}
214
215pub struct Ticket {
216    pub(crate) catalog: String,
217    pub(crate) manager: ProcessManagerRef,
218    pub(crate) id: ProcessId,
219    pub cancellation_handle: Arc<CancellationHandle>,
220}
221
222impl Drop for Ticket {
223    fn drop(&mut self) {
224        self.manager
225            .deregister_query(std::mem::take(&mut self.catalog), self.id);
226    }
227}
228
229struct CancellableProcess {
230    handle: Arc<CancellationHandle>,
231    process: ProcessInfo,
232}
233
234impl Drop for CancellableProcess {
235    fn drop(&mut self) {
236        PROCESS_LIST_COUNT
237            .with_label_values(&[&self.process.catalog])
238            .dec();
239    }
240}
241
242impl CancellableProcess {
243    fn new(handle: Arc<CancellationHandle>, process: ProcessInfo) -> Self {
244        PROCESS_LIST_COUNT
245            .with_label_values(&[&process.catalog])
246            .inc();
247        Self { handle, process }
248    }
249}
250
251impl Debug for CancellableProcess {
252    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
253        f.debug_struct("CancellableProcess")
254            .field("cancelled", &self.handle.is_cancelled())
255            .field("process", &self.process)
256            .finish()
257    }
258}
259
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::process_manager::ProcessManager;

    /// Address used as the local frontend address in every test.
    const LOCAL_ADDR: &str = "127.0.0.1:8000";

    /// Builds a manager without a meta client, bound to [LOCAL_ADDR].
    fn new_manager() -> Arc<ProcessManager> {
        Arc::new(ProcessManager::new(LOCAL_ADDR.to_string(), None))
    }

    #[tokio::test]
    async fn test_register_query() {
        let process_manager = new_manager();
        let ticket = process_manager.register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "".to_string(),
            None,
        );

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(&running_processes[0].frontend, LOCAL_ADDR);
        assert_eq!(running_processes[0].id, ticket.id);
        assert_eq!(&running_processes[0].query, "SELECT * FROM table");

        drop(ticket);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_register_query_with_custom_id() {
        let process_manager = new_manager();
        let custom_id = 12345;

        let ticket = process_manager.register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            Some(custom_id),
        );

        assert_eq!(ticket.id, custom_id);

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(running_processes[0].id, custom_id);
        assert_eq!(&running_processes[0].client, "client1");
    }

    #[tokio::test]
    async fn test_multiple_queries_same_catalog() {
        let process_manager = new_manager();

        let ticket1 = process_manager.register_query(
            "public".to_string(),
            vec!["schema1".to_string()],
            "SELECT * FROM table1".to_string(),
            "client1".to_string(),
            None,
        );

        let ticket2 = process_manager.register_query(
            "public".to_string(),
            vec!["schema2".to_string()],
            "SELECT * FROM table2".to_string(),
            "client2".to_string(),
            None,
        );

        let running_processes = process_manager.local_processes(Some("public")).unwrap();
        assert_eq!(running_processes.len(), 2);

        // Verify both processes are present
        let ids: Vec<u32> = running_processes.iter().map(|p| p.id).collect();
        assert!(ids.contains(&ticket1.id));
        assert!(ids.contains(&ticket2.id));
    }

    #[tokio::test]
    async fn test_multiple_catalogs() {
        let process_manager = new_manager();

        let _ticket1 = process_manager.register_query(
            "catalog1".to_string(),
            vec!["schema1".to_string()],
            "SELECT * FROM table1".to_string(),
            "client1".to_string(),
            None,
        );

        let _ticket2 = process_manager.register_query(
            "catalog2".to_string(),
            vec!["schema2".to_string()],
            "SELECT * FROM table2".to_string(),
            "client2".to_string(),
            None,
        );

        // Test listing processes for specific catalog
        let catalog1_processes = process_manager.local_processes(Some("catalog1")).unwrap();
        assert_eq!(catalog1_processes.len(), 1);
        assert_eq!(&catalog1_processes[0].catalog, "catalog1");

        let catalog2_processes = process_manager.local_processes(Some("catalog2")).unwrap();
        assert_eq!(catalog2_processes.len(), 1);
        assert_eq!(&catalog2_processes[0].catalog, "catalog2");

        // Test listing all processes
        let all_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(all_processes.len(), 2);
    }

    #[tokio::test]
    async fn test_deregister_query() {
        let process_manager = new_manager();

        let ticket = process_manager.register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
        );
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 1);
        process_manager.deregister_query("public".to_string(), ticket.id);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_cancellation_handle() {
        let process_manager = new_manager();

        let ticket = process_manager.register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
        );

        assert!(!ticket.cancellation_handle.is_cancelled());
        ticket.cancellation_handle.cancel();
        assert!(ticket.cancellation_handle.is_cancelled());
    }

    #[tokio::test]
    async fn test_kill_local_process() {
        let process_manager = new_manager();

        let ticket = process_manager.register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
        );
        assert!(!ticket.cancellation_handle.is_cancelled());
        let killed = process_manager
            .kill_process(LOCAL_ADDR.to_string(), "public".to_string(), ticket.id)
            .await
            .unwrap();

        assert!(killed);
        // The kill must reach the query through the handle shared with the
        // ticket, not merely drop the entry from the process list.
        assert!(ticket.cancellation_handle.is_cancelled());
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_kill_nonexistent_process() {
        let process_manager = new_manager();
        let killed = process_manager
            .kill_process(LOCAL_ADDR.to_string(), "public".to_string(), 999)
            .await
            .unwrap();
        assert!(!killed);
    }

    #[tokio::test]
    async fn test_kill_process_nonexistent_catalog() {
        let process_manager = new_manager();
        let killed = process_manager
            .kill_process(LOCAL_ADDR.to_string(), "nonexistent".to_string(), 1)
            .await
            .unwrap();
        assert!(!killed);
    }

    #[tokio::test]
    async fn test_process_info_fields() {
        let process_manager = new_manager();

        let _ticket = process_manager.register_query(
            "test_catalog".to_string(),
            vec!["schema1".to_string(), "schema2".to_string()],
            "SELECT COUNT(*) FROM users WHERE age > 18".to_string(),
            "test_client".to_string(),
            Some(42),
        );

        let processes = process_manager.local_processes(None).unwrap();
        assert_eq!(processes.len(), 1);

        let process = &processes[0];
        assert_eq!(process.id, 42);
        assert_eq!(&process.catalog, "test_catalog");
        assert_eq!(process.schemas, vec!["schema1", "schema2"]);
        assert_eq!(&process.query, "SELECT COUNT(*) FROM users WHERE age > 18");
        assert_eq!(&process.client, "test_client");
        assert_eq!(&process.frontend, LOCAL_ADDR);
        assert!(process.start_timestamp > 0);
    }

    #[tokio::test]
    async fn test_ticket_drop_deregisters_process() {
        let process_manager = new_manager();

        {
            let _ticket = process_manager.register_query(
                "public".to_string(),
                vec!["test".to_string()],
                "SELECT * FROM table".to_string(),
                "client1".to_string(),
                None,
            );

            // Process should be registered
            assert_eq!(process_manager.local_processes(None).unwrap().len(), 1);
        } // ticket goes out of scope here

        // Process should be automatically deregistered
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }
}