flow/adapter/node_context.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Node context, prone to change with every incoming request

use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use common_recordbatch::RecordBatch;
use common_telemetry::trace;
use datatypes::prelude::ConcreteDataType;
use session::context::QueryContext;
use snafu::{OptionExt, ResultExt};
use table::metadata::TableId;
use tokio::sync::{broadcast, mpsc, RwLock};

use crate::adapter::table_source::FlowTableSource;
use crate::adapter::{FlowId, ManagedTableSource, TableName};
use crate::error::{Error, EvalSnafu, TableNotFoundSnafu};
use crate::expr::error::InternalSnafu;
use crate::expr::{Batch, GlobalId};
use crate::metrics::METRIC_FLOW_INPUT_BUF_SIZE;
use crate::plan::TypedPlan;
use crate::repr::{DiffRow, RelationDesc, BATCH_SIZE, BROADCAST_CAP, SEND_BUF_CAP};

/// A context that holds the information of the dataflow
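///
/// An illustrative sketch of how the context is typically driven (hypothetical
/// `ctx`, `table_id`, `rows`, and `datatypes` values; not a definitive walkthrough):
///
/// ```rust,ignore
/// // Route incoming rows to the source sender registered for `table_id`,
/// // then move the buffered rows into the broadcast channels that the
/// // dataflow workers subscribe to.
/// ctx.send(table_id, rows, &datatypes).await?;
/// ctx.flush_all_sender().await?;
/// ```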
#[derive(Debug)]
pub struct FlownodeContext {
    /// mapping from source table to tasks, useful for scheduling which tasks to run when a source table is updated
    pub source_to_tasks: BTreeMap<TableId, BTreeSet<FlowId>>,
    /// mapping from task to sink table, useful for sending data back to the client when a task is done running
    pub flow_to_sink: BTreeMap<FlowId, TableName>,
    pub flow_plans: BTreeMap<FlowId, TypedPlan>,
    pub sink_to_flow: BTreeMap<TableName, FlowId>,
    /// broadcast sender for source table, any incoming write request will be sent to the source table's corresponding sender
    ///
    /// Note that we are getting insert requests with table id, so we should use table id as the key
    pub source_sender: BTreeMap<TableId, SourceSender>,
    /// channel for the sink table; there should be only one receiver, which gets all the data
    /// written to the sink table and sends it back to the client. Since we are mocking the sink
    /// table as a client, we use table name as the key here.
    pub sink_receiver:
        BTreeMap<TableName, (mpsc::UnboundedSender<Batch>, mpsc::UnboundedReceiver<Batch>)>,
    /// used to query the schema of a source table, from metasrv with a local cache
    pub table_source: Box<dyn FlowTableSource>,
    /// All the tables that have been registered in the worker
    pub table_repr: IdToNameMap,
    pub query_context: Option<Arc<QueryContext>>,
}

impl FlownodeContext {
    pub fn new(table_source: Box<dyn FlowTableSource>) -> Self {
        Self {
            source_to_tasks: Default::default(),
            flow_to_sink: Default::default(),
            flow_plans: Default::default(),
            sink_to_flow: Default::default(),
            source_sender: Default::default(),
            sink_receiver: Default::default(),
            table_source,
            table_repr: Default::default(),
            query_context: Default::default(),
        }
    }

    pub fn get_flow_ids(&self, table_id: TableId) -> Option<&BTreeSet<FlowId>> {
        self.source_to_tasks.get(&table_id)
    }
}

/// A simple broadcast sender with backpressure: bounded capacity, blocking on send when the send buf is full.
/// Note that it does not evict old data, so it is possible to block forever if the receiver is slow.
///
/// The receiver side still uses a tokio broadcast channel, since only the sender side needs to know about
/// backpressure and adjust the dataflow's running duration to avoid blocking.
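///
/// A minimal sketch of the intended two-stage pipeline (hypothetical `sender`,
/// `rows`, and `datatypes` values):
///
/// ```rust,ignore
/// let mut rx = sender.get_receiver();         // worker subscribes first
/// sender.send_rows(rows, &datatypes).await?;  // buffered into the mpsc send buf
/// sender.try_flush().await?;                  // moved into the broadcast channel
/// let batch = rx.recv().await?;               // worker consumes the batch
/// ```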
#[derive(Debug)]
pub struct SourceSender {
    // TODO(discord9): make it all Vec<DiffRow>?
    sender: broadcast::Sender<Batch>,
    send_buf_tx: mpsc::Sender<Batch>,
    send_buf_rx: RwLock<mpsc::Receiver<Batch>>,
    send_buf_row_cnt: AtomicUsize,
}

impl Default for SourceSender {
    fn default() -> Self {
        // TODO(discord9): the capacity is arbitrary, we can adjust it later; might also want to limit the max number of rows in the send buf
        let (send_buf_tx, send_buf_rx) = mpsc::channel(SEND_BUF_CAP);
        Self {
            // TODO(discord9): find a better way than increasing this to prevent lagging and hence missing input data
            sender: broadcast::Sender::new(SEND_BUF_CAP),
            send_buf_tx,
            send_buf_rx: RwLock::new(send_buf_rx),
            send_buf_row_cnt: AtomicUsize::new(0),
        }
    }
}

impl SourceSender {
    /// max number of iterations to try to flush the send buf
    const MAX_ITERATIONS: usize = 16;
    pub fn get_receiver(&self) -> broadcast::Receiver<Batch> {
        self.sender.subscribe()
    }

    /// send as many rows as possible from the send buf,
    /// until the send buf is empty or the broadcast channel is full
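    ///
    /// Illustrative behavior (hypothetical `sender` value): each call moves whole
    /// batches from the internal send buf into the broadcast channel and returns
    /// how many rows were moved by this call.
    ///
    /// ```rust,ignore
    /// let moved = sender.try_flush().await?;
    /// if moved == 0 {
    ///     // either the broadcast channel was full or the send buf was empty
    /// }
    /// ```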
    pub async fn try_flush(&self) -> Result<usize, Error> {
        let mut row_cnt = 0;
        loop {
            let mut send_buf = self.send_buf_rx.write().await;
            // if the broadcast channel is full or the send buf is empty,
            // there is nothing to do for now, just break
            if self.sender.len() >= BROADCAST_CAP || send_buf.is_empty() {
                break;
            }
            // TODO(discord9): send rows instead so it's just moving a pointer
            if let Some(batch) = send_buf.recv().await {
                let len = batch.row_count();
                if let Err(prev_row_cnt) =
                    self.send_buf_row_cnt
                        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| x.checked_sub(len))
                {
                    common_telemetry::error!(
                        "send buf row count underflow, prev = {}, len = {}",
                        prev_row_cnt,
                        len
                    );
                }
                row_cnt += len;
                self.sender
                    .send(batch)
                    .map_err(|err| {
                        InternalSnafu {
                            reason: format!("Failed to send row, error = {:?}", err),
                        }
                        .build()
                    })
                    .with_context(|_| EvalSnafu)?;
            }
        }
        if row_cnt > 0 {
            trace!("Source Flushed {} rows", row_cnt);
            METRIC_FLOW_INPUT_BUF_SIZE.sub(row_cnt as _);
            trace!(
                "Remaining Source Send buf.len() = {}",
                METRIC_FLOW_INPUT_BUF_SIZE.get()
            );
        }

        Ok(row_cnt)
    }

    /// Returns the number of rows actually sent (including what's in the buffer)
    pub async fn send_rows(
        &self,
        rows: Vec<DiffRow>,
        batch_datatypes: &[ConcreteDataType],
    ) -> Result<usize, Error> {
        METRIC_FLOW_INPUT_BUF_SIZE.add(rows.len() as _);
        // important for backpressure: if the send buf is full, block until it's not
        while self.send_buf_row_cnt.load(Ordering::SeqCst) >= BATCH_SIZE * 4 {
            tokio::task::yield_now().await;
        }

        // the row count metric is approximate, so relaxed ordering would be OK
        let batch = Batch::try_from_rows_with_types(
            rows.into_iter().map(|(row, _, _)| row).collect(),
            batch_datatypes,
        )
        .context(EvalSnafu)?;
        common_telemetry::trace!("Send one batch to worker with {} rows", batch.row_count());

        self.send_buf_row_cnt
            .fetch_add(batch.row_count(), Ordering::SeqCst);
        self.send_buf_tx.send(batch).await.map_err(|e| {
            crate::error::InternalSnafu {
                reason: format!("Failed to send row, error = {:?}", e),
            }
            .build()
        })?;

        Ok(0)
    }

    /// send a record batch
    pub async fn send_record_batch(&self, batch: RecordBatch) -> Result<usize, Error> {
        let row_cnt = batch.num_rows();
        let batch = Batch::from(batch);

        self.send_buf_row_cnt.fetch_add(row_cnt, Ordering::SeqCst);

        self.send_buf_tx.send(batch).await.map_err(|e| {
            crate::error::InternalSnafu {
                reason: format!("Failed to send batch, error = {:?}", e),
            }
            .build()
        })?;
        Ok(row_cnt)
    }
}

impl FlownodeContext {
    /// Returns the number of rows actually sent (including what's in the buffer)
    ///
    /// TODO(discord9): make this concurrent
    pub async fn send(
        &self,
        table_id: TableId,
        rows: Vec<DiffRow>,
        batch_datatypes: &[ConcreteDataType],
    ) -> Result<usize, Error> {
        let sender = self
            .source_sender
            .get(&table_id)
            .with_context(|| TableNotFoundSnafu {
                name: table_id.to_string(),
            })?;
        sender.send_rows(rows, batch_datatypes).await
    }

    pub async fn send_rb(&self, table_id: TableId, batch: RecordBatch) -> Result<usize, Error> {
        let sender = self
            .source_sender
            .get(&table_id)
            .with_context(|| TableNotFoundSnafu {
                name: table_id.to_string(),
            })?;
        sender.send_record_batch(batch).await
    }

    /// flush all senders' bufs
    ///
    /// returns the number of rows sent
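    ///
    /// A sketch of how a driver loop might use this (hypothetical `ctx` value
    /// and backoff; not the actual scheduling logic):
    ///
    /// ```rust,ignore
    /// loop {
    ///     let flushed = ctx.flush_all_sender().await?;
    ///     if flushed == 0 {
    ///         // nothing was buffered; back off briefly before polling again
    ///         tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    ///     }
    /// }
    /// ```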
    pub async fn flush_all_sender(&self) -> Result<usize, Error> {
        let mut sum = 0;
        for sender in self.source_sender.values() {
            sender.try_flush().await.inspect(|x| sum += x)?;
        }
        Ok(sum)
    }
}

impl FlownodeContext {
    /// map each source table to the task, and the sink table to the task, in the worker context
    ///
    /// also add their corresponding broadcast sender/receiver
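    ///
    /// Illustrative call (hypothetical ids and name, assuming `TableName` is the
    /// `[catalog, schema, table]` triple used by this crate):
    ///
    /// ```rust,ignore
    /// let sink: TableName = ["greptime".into(), "public".into(), "out".into()];
    /// ctx.register_task_src_sink(flow_id, &[src_table_id], sink);
    /// // writes to `src_table_id` are now routed to `flow_id`, and the flow's
    /// // output can be read from the matching entry in `ctx.sink_receiver`
    /// ```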
    pub fn register_task_src_sink(
        &mut self,
        task_id: FlowId,
        source_table_ids: &[TableId],
        sink_table_name: TableName,
    ) {
        for source_table_id in source_table_ids {
            self.add_source_sender_if_not_exist(*source_table_id);
            self.source_to_tasks
                .entry(*source_table_id)
                .or_default()
                .insert(task_id);
        }

        self.add_sink_receiver(sink_table_name.clone());
        self.flow_to_sink.insert(task_id, sink_table_name.clone());
        self.sink_to_flow.insert(sink_table_name, task_id);
    }

    /// add a flow plan to the worker context
    pub fn add_flow_plan(&mut self, task_id: FlowId, plan: TypedPlan) {
        self.flow_plans.insert(task_id, plan);
    }

    pub fn get_flow_plan(&self, task_id: &FlowId) -> Option<TypedPlan> {
        self.flow_plans.get(task_id).cloned()
    }

    /// remove a flow from the worker context
    pub fn remove_flow(&mut self, task_id: FlowId) {
        if let Some(sink_table_name) = self.flow_to_sink.remove(&task_id) {
            self.sink_to_flow.remove(&sink_table_name);
        }
        for (source_table_id, tasks) in self.source_to_tasks.iter_mut() {
            tasks.remove(&task_id);
            if tasks.is_empty() {
                self.source_sender.remove(source_table_id);
            }
        }
        self.flow_plans.remove(&task_id);
    }

    /// try to add a source sender; if one already exists, do nothing
    pub fn add_source_sender_if_not_exist(&mut self, table_id: TableId) {
        let _sender = self.source_sender.entry(table_id).or_default();
    }

    pub fn add_sink_receiver(&mut self, table_name: TableName) {
        self.sink_receiver
            .entry(table_name)
            .or_insert_with(mpsc::unbounded_channel);
    }

    pub fn get_source_by_global_id(&self, id: &GlobalId) -> Result<&SourceSender, Error> {
        let table_id = self
            .table_repr
            .get_by_global_id(id)
            .with_context(|| TableNotFoundSnafu {
                name: format!("Global Id = {:?}", id),
            })?
            .1
            .with_context(|| TableNotFoundSnafu {
                name: format!("Table Id = {:?}", id),
            })?;
        self.source_sender
            .get(&table_id)
            .with_context(|| TableNotFoundSnafu {
                name: table_id.to_string(),
            })
    }

    pub fn get_sink_by_global_id(
        &self,
        id: &GlobalId,
    ) -> Result<mpsc::UnboundedSender<Batch>, Error> {
        let table_name = self
            .table_repr
            .get_by_global_id(id)
            .with_context(|| TableNotFoundSnafu {
                name: format!("{:?}", id),
            })?
            .0
            .with_context(|| TableNotFoundSnafu {
                name: format!("Global Id = {:?}", id),
            })?;
        self.sink_receiver
            .get(&table_name)
            .map(|(s, _r)| s.clone())
            .with_context(|| TableNotFoundSnafu {
                name: table_name.join("."),
            })
    }
}

impl FlownodeContext {
    /// Retrieves a GlobalId and table schema representing a table previously registered by calling the [register_table] function.
    ///
    /// Returns an error if no table has been registered with the provided name
    pub async fn table(&self, name: &TableName) -> Result<(GlobalId, RelationDesc), Error> {
        let id = self
            .table_repr
            .get_by_name(name)
            .map(|(_tid, gid)| gid)
            .with_context(|| TableNotFoundSnafu {
                name: name.join("."),
            })?;
        let schema = self.table_source.table(name).await?;
        Ok((id, schema.relation_desc))
    }

    /// Assign a global id to a table; if one is already assigned, return the existing global id.
    ///
    /// Requires at least one of `table_name` or `table_id` to be `Some`,
    /// and will try to fetch the table name from the table info manager (if the table exists now).
    ///
    /// NOTE: this will not actually render the table into the collection referred to by the GlobalId;
    /// it merely creates a mapping from table id to global id.
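    ///
    /// A sketch of the lookup-or-assign behavior (hypothetical `ctx`, `srv_map`,
    /// and `table_id` values):
    ///
    /// ```rust,ignore
    /// // the first call assigns a fresh GlobalId::User(n)
    /// let gid = ctx.assign_global_id_to_table(&srv_map, None, Some(table_id)).await?;
    /// // a second call with the same table id returns the same global id
    /// let gid2 = ctx.assign_global_id_to_table(&srv_map, None, Some(table_id)).await?;
    /// assert_eq!(gid, gid2);
    /// ```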
    pub async fn assign_global_id_to_table(
        &mut self,
        srv_map: &ManagedTableSource,
        mut table_name: Option<TableName>,
        table_id: Option<TableId>,
    ) -> Result<GlobalId, Error> {
        // if we can find it by table name/id, don't assign a new one
        if let Some(gid) = table_name
            .as_ref()
            .and_then(|table_name| self.table_repr.get_by_name(table_name))
            .map(|(_, gid)| gid)
            .or_else(|| {
                table_id
                    .and_then(|id| self.table_repr.get_by_table_id(&id))
                    .map(|(_, gid)| gid)
            })
        {
            Ok(gid)
        } else {
            let global_id = self.new_global_id();

            // a Some table id means the db must have created the table
            if let Some(table_id) = table_id {
                let known_table_name = srv_map.get_table_name(&table_id).await?;
                table_name = table_name.or(Some(known_table_name));
            } // if we don't have a table id, the database hasn't assigned one yet or we don't need it

            // still update the mapping with the new global id
            self.table_repr.insert(table_name, table_id, global_id);
            Ok(global_id)
        }
    }

    /// Get a new global id
    pub fn new_global_id(&self) -> GlobalId {
        GlobalId::User(self.table_repr.global_id_to_name_id.len() as u64)
    }
}

/// A tri-directional map that maps table name, table id, and global id
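///
/// A round-trip sketch (hypothetical values, assuming `TableName` is the
/// `[catalog, schema, table]` triple used by this crate):
///
/// ```rust,ignore
/// let mut map = IdToNameMap::new();
/// let name: TableName = ["greptime".into(), "public".into(), "numbers".into()];
/// map.insert(Some(name.clone()), Some(1024), GlobalId::User(0));
/// // each key recovers the other two identifiers
/// assert_eq!(map.get_by_name(&name), Some((Some(1024), GlobalId::User(0))));
/// assert_eq!(map.get_by_table_id(&1024), Some((Some(name.clone()), GlobalId::User(0))));
/// assert_eq!(map.get_by_global_id(&GlobalId::User(0)), Some((Some(name), Some(1024))));
/// ```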
#[derive(Default, Debug)]
pub struct IdToNameMap {
    name_to_global_id: HashMap<TableName, GlobalId>,
    id_to_global_id: HashMap<TableId, GlobalId>,
    global_id_to_name_id: BTreeMap<GlobalId, (Option<TableName>, Option<TableId>)>,
}

impl IdToNameMap {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn insert(&mut self, name: Option<TableName>, id: Option<TableId>, global_id: GlobalId) {
        name.clone()
            .and_then(|name| self.name_to_global_id.insert(name.clone(), global_id));
        id.and_then(|id| self.id_to_global_id.insert(id, global_id));
        self.global_id_to_name_id.insert(global_id, (name, id));
    }

    pub fn get_by_name(&self, name: &TableName) -> Option<(Option<TableId>, GlobalId)> {
        self.name_to_global_id.get(name).map(|global_id| {
            let (_name, id) = self.global_id_to_name_id.get(global_id).unwrap();
            (*id, *global_id)
        })
    }

    pub fn get_by_table_id(&self, id: &TableId) -> Option<(Option<TableName>, GlobalId)> {
        self.id_to_global_id.get(id).map(|global_id| {
            let (name, _id) = self.global_id_to_name_id.get(global_id).unwrap();
            (name.clone(), *global_id)
        })
    }

    pub fn get_by_global_id(
        &self,
        global_id: &GlobalId,
    ) -> Option<(Option<TableName>, Option<TableId>)> {
        self.global_id_to_name_id.get(global_id).cloned()
    }
}