use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use common_recordbatch::RecordBatch;
use common_telemetry::trace;
use datatypes::prelude::ConcreteDataType;
use session::context::QueryContext;
use snafu::{OptionExt, ResultExt};
use table::metadata::TableId;
use tokio::sync::{broadcast, mpsc, RwLock};
use crate::adapter::table_source::FlowTableSource;
use crate::adapter::{FlowId, ManagedTableSource, TableName};
use crate::error::{Error, EvalSnafu, TableNotFoundSnafu};
use crate::expr::error::InternalSnafu;
use crate::expr::{Batch, GlobalId};
use crate::metrics::METRIC_FLOW_INPUT_BUF_SIZE;
use crate::plan::TypedPlan;
use crate::repr::{DiffRow, RelationDesc, BATCH_SIZE, BROADCAST_CAP, SEND_BUF_CAP};
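/// The state of a flownode: which flows read which source tables, where each
/// flow writes its output, and the channels wiring sources to workers and
/// workers to sinks.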
#[derive(Debug)]
pub struct FlownodeContext {
    /// Source table -> flows reading from it; used to schedule flows when a
    /// source table receives new data.
    pub source_to_tasks: BTreeMap<TableId, BTreeSet<FlowId>>,
    /// Flow -> its sink table name.
    pub flow_to_sink: BTreeMap<FlowId, TableName>,
    /// The typed plan of each flow.
    pub flow_plans: BTreeMap<FlowId, TypedPlan>,
    /// Reverse of `flow_to_sink`.
    pub sink_to_flow: BTreeMap<TableName, FlowId>,
    /// Input sender per source table, keyed by table id since insert requests carry table ids.
    pub source_sender: BTreeMap<TableId, SourceSender>,
    /// Output channel per sink table; the receiver half drains results destined for that sink.
    pub sink_receiver:
        BTreeMap<TableName, (mpsc::UnboundedSender<Batch>, mpsc::UnboundedReceiver<Batch>)>,
    /// Used to look up table schemas.
    pub table_source: Box<dyn FlowTableSource>,
    /// Maps among table name, table id, and `GlobalId`.
    pub table_repr: IdToNameMap,
    /// Query context used when running flows on this node.
    pub query_context: Option<Arc<QueryContext>>,
}
impl FlownodeContext {
pub fn new(table_source: Box<dyn FlowTableSource>) -> Self {
Self {
source_to_tasks: Default::default(),
flow_to_sink: Default::default(),
flow_plans: Default::default(),
sink_to_flow: Default::default(),
source_sender: Default::default(),
sink_receiver: Default::default(),
table_source,
table_repr: Default::default(),
query_context: Default::default(),
}
}
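    /// Returns the set of flows reading from the given source table, if any.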
pub fn get_flow_ids(&self, table_id: TableId) -> Option<&BTreeSet<FlowId>> {
self.source_to_tasks.get(&table_id)
}
}
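/// Input sender for a single source table. Writes are first staged in an
/// internal mpsc buffer, then flushed into a broadcast channel that workers
/// subscribe to.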
#[derive(Debug)]
pub struct SourceSender {
    /// Broadcast channel fanning batches out to subscribed workers.
    sender: broadcast::Sender<Batch>,
    /// Producer half of the internal staging buffer.
    send_buf_tx: mpsc::Sender<Batch>,
    /// Consumer half of the staging buffer, drained by `try_flush`.
    send_buf_rx: RwLock<mpsc::Receiver<Batch>>,
    /// Number of rows currently staged in the buffer.
    send_buf_row_cnt: AtomicUsize,
}
impl Default for SourceSender {
fn default() -> Self {
let (send_buf_tx, send_buf_rx) = mpsc::channel(SEND_BUF_CAP);
Self {
            // The broadcast channel shares its capacity bound with the staging buffer.
            sender: broadcast::Sender::new(SEND_BUF_CAP),
send_buf_tx,
send_buf_rx: RwLock::new(send_buf_rx),
send_buf_row_cnt: AtomicUsize::new(0),
}
}
}
impl SourceSender {
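    /// Max number of iterations to try when flushing the send buffer.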
const MAX_ITERATIONS: usize = 16;
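    /// Subscribe to the batches flushed by this sender.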
pub fn get_receiver(&self) -> broadcast::Receiver<Batch> {
self.sender.subscribe()
}
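    /// Drain the staging buffer into the broadcast channel, stopping when the
    /// channel is saturated or the buffer is empty. Returns the number of rows
    /// flushed.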
pub async fn try_flush(&self) -> Result<usize, Error> {
let mut row_cnt = 0;
loop {
let mut send_buf = self.send_buf_rx.write().await;
            // Stop once the broadcast channel is saturated or nothing is staged.
            if self.sender.len() >= BROADCAST_CAP || send_buf.is_empty() {
                break;
            }
if let Some(batch) = send_buf.recv().await {
let len = batch.row_count();
if let Err(prev_row_cnt) =
self.send_buf_row_cnt
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| x.checked_sub(len))
{
common_telemetry::error!(
"send buf row count underflow, prev = {}, len = {}",
prev_row_cnt,
len
);
}
row_cnt += len;
self.sender
.send(batch)
.map_err(|err| {
InternalSnafu {
reason: format!("Failed to send row, error = {:?}", err),
}
.build()
})
.with_context(|_| EvalSnafu)?;
}
}
if row_cnt > 0 {
trace!("Source Flushed {} rows", row_cnt);
METRIC_FLOW_INPUT_BUF_SIZE.sub(row_cnt as _);
trace!(
"Remaining Source Send buf.len() = {}",
METRIC_FLOW_INPUT_BUF_SIZE.get()
);
}
Ok(row_cnt)
}
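    /// Stage the given rows as one batch, applying backpressure while the
    /// buffer is too full. Returns `Ok(0)`; rows are counted when flushed.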
pub async fn send_rows(
&self,
rows: Vec<DiffRow>,
batch_datatypes: &[ConcreteDataType],
) -> Result<usize, Error> {
METRIC_FLOW_INPUT_BUF_SIZE.add(rows.len() as _);
        // Backpressure: wait for the staged row count to drop below the
        // threshold before accepting more input.
        while self.send_buf_row_cnt.load(Ordering::SeqCst) >= BATCH_SIZE * 4 {
            tokio::task::yield_now().await;
        }
let batch = Batch::try_from_rows_with_types(
rows.into_iter().map(|(row, _, _)| row).collect(),
batch_datatypes,
)
.context(EvalSnafu)?;
common_telemetry::trace!("Send one batch to worker with {} rows", batch.row_count());
self.send_buf_row_cnt
.fetch_add(batch.row_count(), Ordering::SeqCst);
        self.send_buf_tx.send(batch).await.map_err(|e| {
            crate::error::InternalSnafu {
                reason: format!("Failed to buffer batch, error = {:?}", e),
            }
            .build()
        })?;
        // Rows are only staged here; they are counted as flushed when
        // `try_flush` forwards them to the broadcast channel.
        Ok(0)
}
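    /// Stage a whole `RecordBatch`, converting it into the internal `Batch`
    /// representation. Returns the number of rows accepted.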
pub async fn send_record_batch(&self, batch: RecordBatch) -> Result<usize, Error> {
let row_cnt = batch.num_rows();
let batch = Batch::from(batch);
self.send_buf_row_cnt.fetch_add(row_cnt, Ordering::SeqCst);
self.send_buf_tx.send(batch).await.map_err(|e| {
crate::error::InternalSnafu {
reason: format!("Failed to send batch, error = {:?}", e),
}
.build()
})?;
Ok(row_cnt)
}
}
impl FlownodeContext {
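    /// Send rows to the source sender registered for `table_id`; errors if no
    /// sender is registered for that table.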
pub async fn send(
&self,
table_id: TableId,
rows: Vec<DiffRow>,
batch_datatypes: &[ConcreteDataType],
) -> Result<usize, Error> {
let sender = self
.source_sender
.get(&table_id)
.with_context(|| TableNotFoundSnafu {
name: table_id.to_string(),
})?;
sender.send_rows(rows, batch_datatypes).await
}
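    /// Send a `RecordBatch` to the source sender registered for `table_id`.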
pub async fn send_rb(&self, table_id: TableId, batch: RecordBatch) -> Result<usize, Error> {
let sender = self
.source_sender
.get(&table_id)
.with_context(|| TableNotFoundSnafu {
name: table_id.to_string(),
})?;
sender.send_record_batch(batch).await
}
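    /// Flush every source sender once and return the total rows flushed.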
    pub async fn flush_all_sender(&self) -> Result<usize, Error> {
        let mut sum = 0;
        for sender in self.source_sender.values() {
            sum += sender.try_flush().await?;
        }
        Ok(sum)
    }
}
impl FlownodeContext {
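    /// Register a flow's source tables and sink table, creating the source
    /// senders and the sink channel if they do not already exist.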
pub fn register_task_src_sink(
&mut self,
task_id: FlowId,
source_table_ids: &[TableId],
sink_table_name: TableName,
) {
for source_table_id in source_table_ids {
self.add_source_sender_if_not_exist(*source_table_id);
self.source_to_tasks
.entry(*source_table_id)
.or_default()
.insert(task_id);
}
self.add_sink_receiver(sink_table_name.clone());
self.flow_to_sink.insert(task_id, sink_table_name.clone());
self.sink_to_flow.insert(sink_table_name, task_id);
}
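    /// Record the typed plan of a flow.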
pub fn add_flow_plan(&mut self, task_id: FlowId, plan: TypedPlan) {
self.flow_plans.insert(task_id, plan);
}
pub fn get_flow_plan(&self, task_id: &FlowId) -> Option<TypedPlan> {
self.flow_plans.get(task_id).cloned()
}
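    /// Remove all state belonging to a flow: its sink mappings, its plan, and
    /// its membership in each source table; source senders left with no flows
    /// are dropped too.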
pub fn remove_flow(&mut self, task_id: FlowId) {
if let Some(sink_table_name) = self.flow_to_sink.remove(&task_id) {
self.sink_to_flow.remove(&sink_table_name);
}
for (source_table_id, tasks) in self.source_to_tasks.iter_mut() {
tasks.remove(&task_id);
if tasks.is_empty() {
self.source_sender.remove(source_table_id);
}
}
self.flow_plans.remove(&task_id);
}
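    /// Ensure a source sender exists for the table, creating a default one if needed.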
pub fn add_source_sender_if_not_exist(&mut self, table_id: TableId) {
let _sender = self.source_sender.entry(table_id).or_default();
}
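    /// Ensure a sink channel exists for the table name.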
pub fn add_sink_receiver(&mut self, table_name: TableName) {
self.sink_receiver
.entry(table_name)
.or_insert_with(mpsc::unbounded_channel);
}
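    /// Look up the source sender of the table identified by this `GlobalId`.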
pub fn get_source_by_global_id(&self, id: &GlobalId) -> Result<&SourceSender, Error> {
let table_id = self
.table_repr
.get_by_global_id(id)
.with_context(|| TableNotFoundSnafu {
name: format!("Global Id = {:?}", id),
})?
.1
.with_context(|| TableNotFoundSnafu {
name: format!("Table Id = {:?}", id),
})?;
self.source_sender
.get(&table_id)
.with_context(|| TableNotFoundSnafu {
name: table_id.to_string(),
})
}
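    /// Look up the sink sender half of the table identified by this `GlobalId`.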
pub fn get_sink_by_global_id(
&self,
id: &GlobalId,
) -> Result<mpsc::UnboundedSender<Batch>, Error> {
let table_name = self
.table_repr
.get_by_global_id(id)
.with_context(|| TableNotFoundSnafu {
name: format!("{:?}", id),
})?
.0
.with_context(|| TableNotFoundSnafu {
name: format!("Global Id = {:?}", id),
})?;
self.sink_receiver
.get(&table_name)
.map(|(s, _r)| s.clone())
.with_context(|| TableNotFoundSnafu {
name: table_name.join("."),
})
}
}
impl FlownodeContext {
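    /// Resolve a table name to its `GlobalId` and schema, fetching the schema
    /// from the table source.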
pub async fn table(&self, name: &TableName) -> Result<(GlobalId, RelationDesc), Error> {
let id = self
.table_repr
.get_by_name(name)
.map(|(_tid, gid)| gid)
.with_context(|| TableNotFoundSnafu {
name: name.join("."),
})?;
let schema = self.table_source.table(name).await?;
Ok((id, schema.relation_desc))
}
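    /// Return the existing `GlobalId` of the table (found by name or id), or
    /// assign a fresh one and record the mapping. When only an id is given,
    /// the table name is looked up so both can be registered.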
pub async fn assign_global_id_to_table(
&mut self,
srv_map: &ManagedTableSource,
mut table_name: Option<TableName>,
table_id: Option<TableId>,
) -> Result<GlobalId, Error> {
if let Some(gid) = table_name
.as_ref()
.and_then(|table_name| self.table_repr.get_by_name(table_name))
.map(|(_, gid)| gid)
.or_else(|| {
table_id
.and_then(|id| self.table_repr.get_by_table_id(&id))
.map(|(_, gid)| gid)
})
{
Ok(gid)
} else {
let global_id = self.new_global_id();
if let Some(table_id) = table_id {
let known_table_name = srv_map.get_table_name(&table_id).await?;
table_name = table_name.or(Some(known_table_name));
            }
            self.table_repr.insert(table_name, table_id, global_id);
Ok(global_id)
}
}
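    /// Allocate the next `GlobalId` from the number of ids assigned so far;
    /// this assumes entries are never removed from `global_id_to_name_id`.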
pub fn new_global_id(&self) -> GlobalId {
GlobalId::User(self.table_repr.global_id_to_name_id.len() as u64)
}
}
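/// A tri-directional map among table name, table id, and `GlobalId`. Name and
/// id are optional because a table may be registered before both are known.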
#[derive(Default, Debug)]
pub struct IdToNameMap {
name_to_global_id: HashMap<TableName, GlobalId>,
id_to_global_id: HashMap<TableId, GlobalId>,
global_id_to_name_id: BTreeMap<GlobalId, (Option<TableName>, Option<TableId>)>,
}
impl IdToNameMap {
pub fn new() -> Self {
Default::default()
}
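    /// Register a (name, id) pair under a `GlobalId`, updating all three maps.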
    pub fn insert(&mut self, name: Option<TableName>, id: Option<TableId>, global_id: GlobalId) {
        if let Some(name) = name.clone() {
            self.name_to_global_id.insert(name, global_id);
        }
        if let Some(id) = id {
            self.id_to_global_id.insert(id, global_id);
        }
        self.global_id_to_name_id.insert(global_id, (name, id));
    }
pub fn get_by_name(&self, name: &TableName) -> Option<(Option<TableId>, GlobalId)> {
self.name_to_global_id.get(name).map(|global_id| {
            // The three maps are kept in sync by `insert`, so the entry exists.
            let (_name, id) = self.global_id_to_name_id.get(global_id).unwrap();
(*id, *global_id)
})
}
pub fn get_by_table_id(&self, id: &TableId) -> Option<(Option<TableName>, GlobalId)> {
self.id_to_global_id.get(id).map(|global_id| {
            // The three maps are kept in sync by `insert`, so the entry exists.
            let (name, _id) = self.global_id_to_name_id.get(global_id).unwrap();
(name.clone(), *global_id)
})
}
pub fn get_by_global_id(
&self,
global_id: &GlobalId,
) -> Option<(Option<TableName>, Option<TableId>)> {
self.global_id_to_name_id.get(global_id).cloned()
}
}