1use std::collections::{BTreeMap, BTreeSet, HashMap};
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::Arc;
20
21use common_recordbatch::RecordBatch;
22use common_telemetry::trace;
23use datatypes::prelude::ConcreteDataType;
24use session::context::QueryContext;
25use snafu::{OptionExt, ResultExt};
26use table::metadata::TableId;
27use tokio::sync::{broadcast, mpsc, RwLock};
28
29use crate::adapter::table_source::FlowTableSource;
30use crate::adapter::{FlowId, ManagedTableSource, TableName};
31use crate::error::{Error, EvalSnafu, TableNotFoundSnafu};
32use crate::expr::error::InternalSnafu;
33use crate::expr::{Batch, GlobalId};
34use crate::metrics::METRIC_FLOW_INPUT_BUF_SIZE;
35use crate::plan::TypedPlan;
36use crate::repr::{DiffRow, RelationDesc, BATCH_SIZE, BROADCAST_CAP, SEND_BUF_CAP};
37
/// Bookkeeping for registered flows: which flows read which source tables,
/// which sink table each flow writes to, and the channels that carry batches
/// into and out of the flows.
#[derive(Debug)]
pub struct FlownodeContext {
    /// Source table id -> ids of the flows registered against that table.
    pub source_to_tasks: BTreeMap<TableId, BTreeSet<FlowId>>,
    /// Flow id -> name of the sink table the flow was registered with.
    pub flow_to_sink: BTreeMap<FlowId, TableName>,
    /// Flow id -> typed plan stored through `add_flow_plan`.
    pub flow_plans: BTreeMap<FlowId, TypedPlan>,
    /// Sink table name -> id of the flow registered with that sink
    /// (the reverse direction of `flow_to_sink`).
    pub sink_to_flow: BTreeMap<TableName, FlowId>,
    /// Source table id -> sender used to push incoming batches for that table.
    pub source_sender: BTreeMap<TableId, SourceSender>,
    /// Sink table name -> both halves of an unbounded channel of batches.
    ///
    /// `get_sink_by_global_id` hands out clones of the sending half.
    pub sink_receiver:
        BTreeMap<TableName, (mpsc::UnboundedSender<Batch>, mpsc::UnboundedReceiver<Batch>)>,
    /// Schema lookup used by `table` to obtain a table's relation description.
    pub table_source: Box<dyn FlowTableSource>,
    /// Mapping between table names, table ids and the global ids assigned to them.
    pub table_repr: IdToNameMap,
    /// Optional query context; `new` initialises it to `None`.
    pub query_context: Option<Arc<QueryContext>>,
}
63
64impl FlownodeContext {
65 pub fn new(table_source: Box<dyn FlowTableSource>) -> Self {
66 Self {
67 source_to_tasks: Default::default(),
68 flow_to_sink: Default::default(),
69 flow_plans: Default::default(),
70 sink_to_flow: Default::default(),
71 source_sender: Default::default(),
72 sink_receiver: Default::default(),
73 table_source,
74 table_repr: Default::default(),
75 query_context: Default::default(),
76 }
77 }
78
79 pub fn get_flow_ids(&self, table_id: TableId) -> Option<&BTreeSet<FlowId>> {
80 self.source_to_tasks.get(&table_id)
81 }
82}
83
/// Feeds batches for one source table to broadcast subscribers.
///
/// Incoming batches are first queued in a bounded mpsc staging buffer and are
/// later moved onto the broadcast channel by `try_flush`.
#[derive(Debug)]
pub struct SourceSender {
    /// Broadcast channel that `try_flush` publishes to and `get_receiver` subscribes to.
    sender: broadcast::Sender<Batch>,
    /// Producer half of the staging buffer, written by `send_rows` and `send_record_batch`.
    send_buf_tx: mpsc::Sender<Batch>,
    /// Consumer half of the staging buffer; `try_flush` takes the write lock to drain it.
    send_buf_rx: RwLock<mpsc::Receiver<Batch>>,
    /// Rows currently queued in the staging buffer. `send_rows` waits while this is
    /// at or above `BATCH_SIZE * 4`.
    send_buf_row_cnt: AtomicUsize,
}
97
98impl Default for SourceSender {
99 fn default() -> Self {
100 let (send_buf_tx, send_buf_rx) = mpsc::channel(SEND_BUF_CAP);
102 Self {
103 sender: broadcast::Sender::new(SEND_BUF_CAP),
105 send_buf_tx,
106 send_buf_rx: RwLock::new(send_buf_rx),
107 send_buf_row_cnt: AtomicUsize::new(0),
108 }
109 }
110}
111
112impl SourceSender {
113 const MAX_ITERATIONS: usize = 16;
115 pub fn get_receiver(&self) -> broadcast::Receiver<Batch> {
116 self.sender.subscribe()
117 }
118
119 pub async fn try_flush(&self) -> Result<usize, Error> {
122 let mut row_cnt = 0;
123 loop {
124 let mut send_buf = self.send_buf_rx.write().await;
125 if self.sender.len() >= BROADCAST_CAP || send_buf.is_empty() {
128 break;
129 }
130 if let Some(batch) = send_buf.recv().await {
132 let len = batch.row_count();
133 if let Err(prev_row_cnt) =
134 self.send_buf_row_cnt
135 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| x.checked_sub(len))
136 {
137 common_telemetry::error!(
138 "send buf row count underflow, prev = {}, len = {}",
139 prev_row_cnt,
140 len
141 );
142 }
143 row_cnt += len;
144 self.sender
145 .send(batch)
146 .map_err(|err| {
147 InternalSnafu {
148 reason: format!("Failed to send row, error = {:?}", err),
149 }
150 .build()
151 })
152 .with_context(|_| EvalSnafu)?;
153 }
154 }
155 if row_cnt > 0 {
156 trace!("Source Flushed {} rows", row_cnt);
157 METRIC_FLOW_INPUT_BUF_SIZE.sub(row_cnt as _);
158 trace!(
159 "Remaining Source Send buf.len() = {}",
160 METRIC_FLOW_INPUT_BUF_SIZE.get()
161 );
162 }
163
164 Ok(row_cnt)
165 }
166
167 pub async fn send_rows(
169 &self,
170 rows: Vec<DiffRow>,
171 batch_datatypes: &[ConcreteDataType],
172 ) -> Result<usize, Error> {
173 METRIC_FLOW_INPUT_BUF_SIZE.add(rows.len() as _);
174 while self.send_buf_row_cnt.load(Ordering::SeqCst) >= BATCH_SIZE * 4 {
176 tokio::task::yield_now().await;
177 }
178
179 let batch = Batch::try_from_rows_with_types(
181 rows.into_iter().map(|(row, _, _)| row).collect(),
182 batch_datatypes,
183 )
184 .context(EvalSnafu)?;
185 common_telemetry::trace!("Send one batch to worker with {} rows", batch.row_count());
186
187 self.send_buf_row_cnt
188 .fetch_add(batch.row_count(), Ordering::SeqCst);
189 self.send_buf_tx.send(batch).await.map_err(|e| {
190 crate::error::InternalSnafu {
191 reason: format!("Failed to send row, error = {:?}", e),
192 }
193 .build()
194 })?;
195
196 Ok(0)
197 }
198
199 pub async fn send_record_batch(&self, batch: RecordBatch) -> Result<usize, Error> {
201 let row_cnt = batch.num_rows();
202 let batch = Batch::from(batch);
203
204 self.send_buf_row_cnt.fetch_add(row_cnt, Ordering::SeqCst);
205
206 self.send_buf_tx.send(batch).await.map_err(|e| {
207 crate::error::InternalSnafu {
208 reason: format!("Failed to send batch, error = {:?}", e),
209 }
210 .build()
211 })?;
212 Ok(row_cnt)
213 }
214}
215
216impl FlownodeContext {
217 pub async fn send(
221 &self,
222 table_id: TableId,
223 rows: Vec<DiffRow>,
224 batch_datatypes: &[ConcreteDataType],
225 ) -> Result<usize, Error> {
226 let sender = self
227 .source_sender
228 .get(&table_id)
229 .with_context(|| TableNotFoundSnafu {
230 name: table_id.to_string(),
231 })?;
232 sender.send_rows(rows, batch_datatypes).await
233 }
234
235 pub async fn send_rb(&self, table_id: TableId, batch: RecordBatch) -> Result<usize, Error> {
236 let sender = self
237 .source_sender
238 .get(&table_id)
239 .with_context(|| TableNotFoundSnafu {
240 name: table_id.to_string(),
241 })?;
242 sender.send_record_batch(batch).await
243 }
244
245 pub async fn flush_all_sender(&self) -> Result<usize, Error> {
249 let mut sum = 0;
250 for sender in self.source_sender.values() {
251 sender.try_flush().await.inspect(|x| sum += x)?;
252 }
253 Ok(sum)
254 }
255}
256
257impl FlownodeContext {
258 pub fn register_task_src_sink(
262 &mut self,
263 task_id: FlowId,
264 source_table_ids: &[TableId],
265 sink_table_name: TableName,
266 ) {
267 for source_table_id in source_table_ids {
268 self.add_source_sender_if_not_exist(*source_table_id);
269 self.source_to_tasks
270 .entry(*source_table_id)
271 .or_default()
272 .insert(task_id);
273 }
274
275 self.add_sink_receiver(sink_table_name.clone());
276 self.flow_to_sink.insert(task_id, sink_table_name.clone());
277 self.sink_to_flow.insert(sink_table_name, task_id);
278 }
279
280 pub fn add_flow_plan(&mut self, task_id: FlowId, plan: TypedPlan) {
282 self.flow_plans.insert(task_id, plan);
283 }
284
285 pub fn get_flow_plan(&self, task_id: &FlowId) -> Option<TypedPlan> {
286 self.flow_plans.get(task_id).cloned()
287 }
288
289 pub fn remove_flow(&mut self, task_id: FlowId) {
291 if let Some(sink_table_name) = self.flow_to_sink.remove(&task_id) {
292 self.sink_to_flow.remove(&sink_table_name);
293 }
294 for (source_table_id, tasks) in self.source_to_tasks.iter_mut() {
295 tasks.remove(&task_id);
296 if tasks.is_empty() {
297 self.source_sender.remove(source_table_id);
298 }
299 }
300 self.flow_plans.remove(&task_id);
301 }
302
303 pub fn add_source_sender_if_not_exist(&mut self, table_id: TableId) {
305 let _sender = self.source_sender.entry(table_id).or_default();
306 }
307
308 pub fn add_sink_receiver(&mut self, table_name: TableName) {
309 self.sink_receiver
310 .entry(table_name)
311 .or_insert_with(mpsc::unbounded_channel);
312 }
313
314 pub fn get_source_by_global_id(&self, id: &GlobalId) -> Result<&SourceSender, Error> {
315 let table_id = self
316 .table_repr
317 .get_by_global_id(id)
318 .with_context(|| TableNotFoundSnafu {
319 name: format!("Global Id = {:?}", id),
320 })?
321 .1
322 .with_context(|| TableNotFoundSnafu {
323 name: format!("Table Id = {:?}", id),
324 })?;
325 self.source_sender
326 .get(&table_id)
327 .with_context(|| TableNotFoundSnafu {
328 name: table_id.to_string(),
329 })
330 }
331
332 pub fn get_sink_by_global_id(
333 &self,
334 id: &GlobalId,
335 ) -> Result<mpsc::UnboundedSender<Batch>, Error> {
336 let table_name = self
337 .table_repr
338 .get_by_global_id(id)
339 .with_context(|| TableNotFoundSnafu {
340 name: format!("{:?}", id),
341 })?
342 .0
343 .with_context(|| TableNotFoundSnafu {
344 name: format!("Global Id = {:?}", id),
345 })?;
346 self.sink_receiver
347 .get(&table_name)
348 .map(|(s, _r)| s.clone())
349 .with_context(|| TableNotFoundSnafu {
350 name: table_name.join("."),
351 })
352 }
353}
354
355impl FlownodeContext {
356 pub async fn table(&self, name: &TableName) -> Result<(GlobalId, RelationDesc), Error> {
360 let id = self
361 .table_repr
362 .get_by_name(name)
363 .map(|(_tid, gid)| gid)
364 .with_context(|| TableNotFoundSnafu {
365 name: name.join("."),
366 })?;
367 let schema = self.table_source.table(name).await?;
368 Ok((id, schema.relation_desc))
369 }
370
371 pub async fn assign_global_id_to_table(
380 &mut self,
381 srv_map: &ManagedTableSource,
382 mut table_name: Option<TableName>,
383 table_id: Option<TableId>,
384 ) -> Result<GlobalId, Error> {
385 if let Some(gid) = table_name
387 .as_ref()
388 .and_then(|table_name| self.table_repr.get_by_name(table_name))
389 .map(|(_, gid)| gid)
390 .or_else(|| {
391 table_id
392 .and_then(|id| self.table_repr.get_by_table_id(&id))
393 .map(|(_, gid)| gid)
394 })
395 {
396 Ok(gid)
397 } else {
398 let global_id = self.new_global_id();
399
400 if let Some(table_id) = table_id {
402 let known_table_name = srv_map.get_table_name(&table_id).await?;
403 table_name = table_name.or(Some(known_table_name));
404 } self.table_repr.insert(table_name, table_id, global_id);
408 Ok(global_id)
409 }
410 }
411
412 pub fn new_global_id(&self) -> GlobalId {
414 GlobalId::User(self.table_repr.global_id_to_name_id.len() as u64)
415 }
416}
417
418#[derive(Default, Debug)]
420pub struct IdToNameMap {
421 name_to_global_id: HashMap<TableName, GlobalId>,
422 id_to_global_id: HashMap<TableId, GlobalId>,
423 global_id_to_name_id: BTreeMap<GlobalId, (Option<TableName>, Option<TableId>)>,
424}
425
426impl IdToNameMap {
427 pub fn new() -> Self {
428 Default::default()
429 }
430
431 pub fn insert(&mut self, name: Option<TableName>, id: Option<TableId>, global_id: GlobalId) {
432 name.clone()
433 .and_then(|name| self.name_to_global_id.insert(name.clone(), global_id));
434 id.and_then(|id| self.id_to_global_id.insert(id, global_id));
435 self.global_id_to_name_id.insert(global_id, (name, id));
436 }
437
438 pub fn get_by_name(&self, name: &TableName) -> Option<(Option<TableId>, GlobalId)> {
439 self.name_to_global_id.get(name).map(|global_id| {
440 let (_name, id) = self.global_id_to_name_id.get(global_id).unwrap();
441 (*id, *global_id)
442 })
443 }
444
445 pub fn get_by_table_id(&self, id: &TableId) -> Option<(Option<TableName>, GlobalId)> {
446 self.id_to_global_id.get(id).map(|global_id| {
447 let (name, _id) = self.global_id_to_name_id.get(global_id).unwrap();
448 (name.clone(), *global_id)
449 })
450 }
451
452 pub fn get_by_global_id(
453 &self,
454 global_id: &GlobalId,
455 ) -> Option<(Option<TableName>, Option<TableId>)> {
456 self.global_id_to_name_id.get(global_id).cloned()
457 }
458}