catalog/
process_manager.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::hash_map::Entry;
16use std::collections::HashMap;
17use std::fmt::{Debug, Formatter};
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::sync::{Arc, RwLock};
20
21use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo};
22use common_base::cancellation::CancellationHandle;
23use common_frontend::selector::{FrontendSelector, MetaClientSelector};
24use common_telemetry::{debug, info, warn};
25use common_time::util::current_time_millis;
26use meta_client::MetaClientRef;
27use snafu::{ensure, OptionExt, ResultExt};
28
29use crate::error;
30use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT};
31
32pub type ProcessId = u32;
33pub type ProcessManagerRef = Arc<ProcessManager>;
34
35/// Query process manager.
36pub struct ProcessManager {
37    /// Local frontend server address,
38    server_addr: String,
39    /// Next process id for local queries.
40    next_id: AtomicU32,
41    /// Running process per catalog.
42    catalogs: RwLock<HashMap<String, HashMap<ProcessId, CancellableProcess>>>,
43    /// Frontend selector to locate frontend nodes.
44    frontend_selector: Option<MetaClientSelector>,
45}
46
47impl ProcessManager {
48    /// Create a [ProcessManager] instance with server address and kv client.
49    pub fn new(server_addr: String, meta_client: Option<MetaClientRef>) -> Self {
50        let frontend_selector = meta_client.map(MetaClientSelector::new);
51        Self {
52            server_addr,
53            next_id: Default::default(),
54            catalogs: Default::default(),
55            frontend_selector,
56        }
57    }
58}
59
60impl ProcessManager {
61    /// Registers a submitted query. Use the provided id if present.
62    #[must_use]
63    pub fn register_query(
64        self: &Arc<Self>,
65        catalog: String,
66        schemas: Vec<String>,
67        query: String,
68        client: String,
69        query_id: Option<ProcessId>,
70    ) -> Ticket {
71        let id = query_id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
72        let process = ProcessInfo {
73            id,
74            catalog: catalog.clone(),
75            schemas,
76            query,
77            start_timestamp: current_time_millis(),
78            client,
79            frontend: self.server_addr.clone(),
80        };
81        let cancellation_handle = Arc::new(CancellationHandle::default());
82        let cancellable_process = CancellableProcess::new(cancellation_handle.clone(), process);
83
84        self.catalogs
85            .write()
86            .unwrap()
87            .entry(catalog.clone())
88            .or_default()
89            .insert(id, cancellable_process);
90
91        Ticket {
92            catalog,
93            manager: self.clone(),
94            id,
95            cancellation_handle,
96        }
97    }
98
99    /// Generates the next process id.
100    pub fn next_id(&self) -> u32 {
101        self.next_id.fetch_add(1, Ordering::Relaxed)
102    }
103
104    /// De-register a query from process list.
105    pub fn deregister_query(&self, catalog: String, id: ProcessId) {
106        if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) {
107            let process = o.get_mut().remove(&id);
108            debug!("Deregister process: {:?}", process);
109            if o.get().is_empty() {
110                o.remove();
111            }
112        }
113    }
114
115    /// List local running processes in given catalog.
116    pub fn local_processes(&self, catalog: Option<&str>) -> error::Result<Vec<ProcessInfo>> {
117        let catalogs = self.catalogs.read().unwrap();
118        let result = if let Some(catalog) = catalog {
119            if let Some(catalogs) = catalogs.get(catalog) {
120                catalogs.values().map(|p| p.process.clone()).collect()
121            } else {
122                vec![]
123            }
124        } else {
125            catalogs
126                .values()
127                .flat_map(|v| v.values().map(|p| p.process.clone()))
128                .collect()
129        };
130        Ok(result)
131    }
132
133    pub async fn list_all_processes(
134        &self,
135        catalog: Option<&str>,
136    ) -> error::Result<Vec<ProcessInfo>> {
137        let mut processes = vec![];
138        if let Some(remote_frontend_selector) = self.frontend_selector.as_ref() {
139            let frontends = remote_frontend_selector
140                .select(|node| node.peer.addr != self.server_addr)
141                .await
142                .context(error::InvokeFrontendSnafu)?;
143            for mut f in frontends {
144                let result = f
145                    .list_process(ListProcessRequest {
146                        catalog: catalog.unwrap_or_default().to_string(),
147                    })
148                    .await
149                    .context(error::InvokeFrontendSnafu);
150                match result {
151                    Ok(resp) => {
152                        processes.extend(resp.processes);
153                    }
154                    Err(e) => {
155                        warn!(e; "Skipping failing node: {:?}", f)
156                    }
157                }
158            }
159        }
160        processes.extend(self.local_processes(catalog)?);
161        Ok(processes)
162    }
163
164    /// Kills query with provided catalog and id.
165    pub async fn kill_process(
166        &self,
167        server_addr: String,
168        catalog: String,
169        id: ProcessId,
170    ) -> error::Result<bool> {
171        if server_addr == self.server_addr {
172            self.kill_local_process(catalog, id).await
173        } else {
174            let mut nodes = self
175                .frontend_selector
176                .as_ref()
177                .context(error::MetaClientMissingSnafu)?
178                .select(|node| node.peer.addr == server_addr)
179                .await
180                .context(error::InvokeFrontendSnafu)?;
181            ensure!(
182                !nodes.is_empty(),
183                error::FrontendNotFoundSnafu { addr: server_addr }
184            );
185
186            let request = KillProcessRequest {
187                server_addr,
188                catalog,
189                process_id: id,
190            };
191            nodes[0]
192                .kill_process(request)
193                .await
194                .context(error::InvokeFrontendSnafu)?;
195            Ok(true)
196        }
197    }
198
199    /// Kills local query with provided catalog and id.
200    pub async fn kill_local_process(&self, catalog: String, id: ProcessId) -> error::Result<bool> {
201        if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
202            if let Some(process) = catalogs.remove(&id) {
203                process.handle.cancel();
204                info!(
205                    "Killed process, catalog: {}, id: {:?}",
206                    process.process.catalog, process.process.id
207                );
208                PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
209                Ok(true)
210            } else {
211                debug!("Failed to kill process, id not found: {}", id);
212                Ok(false)
213            }
214        } else {
215            debug!("Failed to kill process, catalog not found: {}", catalog);
216            Ok(false)
217        }
218    }
219}
220
221pub struct Ticket {
222    pub(crate) catalog: String,
223    pub(crate) manager: ProcessManagerRef,
224    pub(crate) id: ProcessId,
225    pub cancellation_handle: Arc<CancellationHandle>,
226}
227
228impl Drop for Ticket {
229    fn drop(&mut self) {
230        self.manager
231            .deregister_query(std::mem::take(&mut self.catalog), self.id);
232    }
233}
234
235struct CancellableProcess {
236    handle: Arc<CancellationHandle>,
237    process: ProcessInfo,
238}
239
240impl Drop for CancellableProcess {
241    fn drop(&mut self) {
242        PROCESS_LIST_COUNT
243            .with_label_values(&[&self.process.catalog])
244            .dec();
245    }
246}
247
248impl CancellableProcess {
249    fn new(handle: Arc<CancellationHandle>, process: ProcessInfo) -> Self {
250        PROCESS_LIST_COUNT
251            .with_label_values(&[&process.catalog])
252            .inc();
253        Self { handle, process }
254    }
255}
256
257impl Debug for CancellableProcess {
258    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
259        f.debug_struct("CancellableProcess")
260            .field("cancelled", &self.handle.is_cancelled())
261            .field("process", &self.process)
262            .finish()
263    }
264}
265
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::process_manager::ProcessManager;

    /// Frontend address used by every test manager.
    const TEST_ADDR: &str = "127.0.0.1:8000";

    /// Builds a manager without a meta client, i.e. one that only knows local processes.
    fn new_manager() -> Arc<ProcessManager> {
        Arc::new(ProcessManager::new(TEST_ADDR.to_string(), None))
    }

    #[tokio::test]
    async fn test_register_query() {
        let process_manager = new_manager();
        let ticket = process_manager.register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "".to_string(),
            None,
        );

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(&running_processes[0].frontend, TEST_ADDR);
        assert_eq!(running_processes[0].id, ticket.id);
        assert_eq!(&running_processes[0].query, "SELECT * FROM table");

        drop(ticket);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_register_query_with_custom_id() {
        let process_manager = new_manager();
        let custom_id = 12345;

        let ticket = process_manager.register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            Some(custom_id),
        );

        assert_eq!(ticket.id, custom_id);

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(running_processes[0].id, custom_id);
        assert_eq!(&running_processes[0].client, "client1");
    }

    #[tokio::test]
    async fn test_multiple_queries_same_catalog() {
        let process_manager = new_manager();

        let ticket1 = process_manager.register_query(
            "public".to_string(),
            vec!["schema1".to_string()],
            "SELECT * FROM table1".to_string(),
            "client1".to_string(),
            None,
        );

        let ticket2 = process_manager.register_query(
            "public".to_string(),
            vec!["schema2".to_string()],
            "SELECT * FROM table2".to_string(),
            "client2".to_string(),
            None,
        );

        let running_processes = process_manager.local_processes(Some("public")).unwrap();
        assert_eq!(running_processes.len(), 2);

        // Verify both processes are present
        let ids: Vec<u32> = running_processes.iter().map(|p| p.id).collect();
        assert!(ids.contains(&ticket1.id));
        assert!(ids.contains(&ticket2.id));
    }

    #[tokio::test]
    async fn test_multiple_catalogs() {
        let process_manager = new_manager();

        let _ticket1 = process_manager.register_query(
            "catalog1".to_string(),
            vec!["schema1".to_string()],
            "SELECT * FROM table1".to_string(),
            "client1".to_string(),
            None,
        );

        let _ticket2 = process_manager.register_query(
            "catalog2".to_string(),
            vec!["schema2".to_string()],
            "SELECT * FROM table2".to_string(),
            "client2".to_string(),
            None,
        );

        // Test listing processes for specific catalog
        let catalog1_processes = process_manager.local_processes(Some("catalog1")).unwrap();
        assert_eq!(catalog1_processes.len(), 1);
        assert_eq!(&catalog1_processes[0].catalog, "catalog1");

        let catalog2_processes = process_manager.local_processes(Some("catalog2")).unwrap();
        assert_eq!(catalog2_processes.len(), 1);
        assert_eq!(&catalog2_processes[0].catalog, "catalog2");

        // Test listing all processes
        let all_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(all_processes.len(), 2);
    }

    #[tokio::test]
    async fn test_deregister_query() {
        let process_manager = new_manager();

        let ticket = process_manager.register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
        );
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 1);
        process_manager.deregister_query("public".to_string(), ticket.id);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_cancellation_handle() {
        let process_manager = new_manager();

        let ticket = process_manager.register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
        );

        assert!(!ticket.cancellation_handle.is_cancelled());
        ticket.cancellation_handle.cancel();
        assert!(ticket.cancellation_handle.is_cancelled());
    }

    #[tokio::test]
    async fn test_kill_local_process() {
        let process_manager = new_manager();

        let ticket = process_manager.register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
        );
        assert!(!ticket.cancellation_handle.is_cancelled());
        let killed = process_manager
            .kill_process(TEST_ADDR.to_string(), "public".to_string(), ticket.id)
            .await
            .unwrap();

        assert!(killed);
        // Killing must trigger the handle shared with the ticket, not merely
        // remove the entry from the process list.
        assert!(ticket.cancellation_handle.is_cancelled());
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_kill_nonexistent_process() {
        let process_manager = new_manager();
        let killed = process_manager
            .kill_process(TEST_ADDR.to_string(), "public".to_string(), 999)
            .await
            .unwrap();
        assert!(!killed);
    }

    #[tokio::test]
    async fn test_kill_process_nonexistent_catalog() {
        let process_manager = new_manager();
        let killed = process_manager
            .kill_process(TEST_ADDR.to_string(), "nonexistent".to_string(), 1)
            .await
            .unwrap();
        assert!(!killed);
    }

    #[tokio::test]
    async fn test_process_info_fields() {
        let process_manager = new_manager();

        let _ticket = process_manager.register_query(
            "test_catalog".to_string(),
            vec!["schema1".to_string(), "schema2".to_string()],
            "SELECT COUNT(*) FROM users WHERE age > 18".to_string(),
            "test_client".to_string(),
            Some(42),
        );

        let processes = process_manager.local_processes(None).unwrap();
        assert_eq!(processes.len(), 1);

        let process = &processes[0];
        assert_eq!(process.id, 42);
        assert_eq!(&process.catalog, "test_catalog");
        assert_eq!(process.schemas, vec!["schema1", "schema2"]);
        assert_eq!(&process.query, "SELECT COUNT(*) FROM users WHERE age > 18");
        assert_eq!(&process.client, "test_client");
        assert_eq!(&process.frontend, TEST_ADDR);
        assert!(process.start_timestamp > 0);
    }

    #[tokio::test]
    async fn test_ticket_drop_deregisters_process() {
        let process_manager = new_manager();

        {
            let _ticket = process_manager.register_query(
                "public".to_string(),
                vec!["test".to_string()],
                "SELECT * FROM table".to_string(),
                "client1".to_string(),
                None,
            );

            // Process should be registered
            assert_eq!(process_manager.local_processes(None).unwrap().len(), 1);
        } // ticket goes out of scope here

        // Process should be automatically deregistered
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }
}