meta_srv/pubsub/
subscribe_manager.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_telemetry::info;
18use dashmap::DashMap;
19use tokio::sync::mpsc::Sender;
20
21use crate::error::Result;
22use crate::pubsub::{Message, Subscriber, SubscriberRef, Topic, Transport};
23
24pub trait SubscriptionQuery<T>: Send + Sync {
25    fn subscribers_by_topic(&self, topic: &Topic) -> Vec<SubscriberRef<T>>;
26}
27
28pub trait SubscriptionManager<T>: SubscriptionQuery<T> {
29    fn subscribe(&self, req: SubscribeRequest<T>) -> Result<()>;
30
31    fn unsubscribe(&self, req: UnsubscribeRequest) -> Result<()>;
32
33    fn unsubscribe_all(&self) -> Result<()>;
34}
35
36pub type SubscriptionManagerRef = Arc<dyn SubscriptionManager<Sender<Message>>>;
37
38pub struct SubscribeRequest<T> {
39    pub topics: Vec<Topic>,
40    pub subscriber: Subscriber<T>,
41}
42
/// Request to remove a subscriber, identified by its id.
#[derive(Debug, Clone)]
pub struct UnsubscribeRequest {
    /// Id of the subscriber to remove.
    pub subscriber_id: u32,
}
47
48pub struct DefaultSubscribeManager<T> {
49    topic_to_subscribers: DashMap<Topic, Vec<Arc<Subscriber<T>>>>,
50}
51
52impl<T> Default for DefaultSubscribeManager<T> {
53    fn default() -> Self {
54        Self {
55            topic_to_subscribers: DashMap::new(),
56        }
57    }
58}
59
60impl<T> SubscriptionQuery<T> for DefaultSubscribeManager<T>
61where
62    T: Transport,
63{
64    fn subscribers_by_topic(&self, topic: &Topic) -> Vec<SubscriberRef<T>> {
65        self.topic_to_subscribers
66            .get(topic)
67            .map(|list_ref| list_ref.clone())
68            .unwrap_or_default()
69    }
70}
71
72impl<T> SubscriptionManager<T> for DefaultSubscribeManager<T>
73where
74    T: Transport,
75{
76    fn subscribe(&self, req: SubscribeRequest<T>) -> Result<()> {
77        let SubscribeRequest { topics, subscriber } = req;
78
79        info!(
80            "Add a subscriber, subscriber_id: {}, subscriber_name: {}, topics: {:?}",
81            subscriber.id(),
82            subscriber.name(),
83            topics
84        );
85
86        let subscriber = Arc::new(subscriber);
87
88        for topic in topics {
89            let mut entry = self.topic_to_subscribers.entry(topic).or_default();
90            entry.push(subscriber.clone());
91        }
92
93        Ok(())
94    }
95
96    fn unsubscribe(&self, req: UnsubscribeRequest) -> Result<()> {
97        let UnsubscribeRequest { subscriber_id } = req;
98
99        info!("Remove a subscriber, subscriber_id: {}", subscriber_id);
100
101        for mut subscribers in self.topic_to_subscribers.iter_mut() {
102            subscribers.retain(|subscriber| subscriber.id() != subscriber_id)
103        }
104
105        Ok(())
106    }
107
108    fn unsubscribe_all(&self) -> Result<()> {
109        self.topic_to_subscribers.clear();
110
111        Ok(())
112    }
113}