use std::collections::BTreeMap;

use dfir_rs::scheduled::graph::Dfir;
use dfir_rs::scheduled::graph_ext::GraphExt;
use dfir_rs::scheduled::port::{PortCtx, SEND};
use itertools::Itertools;
use snafu::OptionExt;

use crate::compute::state::{DataflowState, Scheduler};
use crate::compute::types::{Collection, CollectionBundle, ErrCollector, Toff};
use crate::error::{Error, InvalidQuerySnafu, NotImplementedSnafu};
use crate::expr::{self, Batch, GlobalId, LocalId};
use crate::plan::{Plan, TypedPlan};
use crate::repr::{self, DiffRow, RelationType};

mod map;
mod reduce;
mod src_sink;
pub struct Context<'referred, 'df> {
    pub id: GlobalId,
    pub df: &'referred mut Dfir<'df>,
    pub compute_state: &'referred mut DataflowState,
    /// Input collections of rows, keyed by the global id that refers to them.
    pub input_collection: BTreeMap<GlobalId, CollectionBundle>,
    /// Stack of local scopes used by `Plan::Let`/`Plan::Get` to bind and
    /// resolve local variables; inner scopes shadow outer ones.
    pub local_scope: Vec<BTreeMap<LocalId, CollectionBundle>>,
    /// Batch-mode counterpart of `input_collection`.
    pub input_collection_batch: BTreeMap<GlobalId, CollectionBundle<Batch>>,
    /// Batch-mode counterpart of `local_scope`.
    pub local_scope_batch: Vec<BTreeMap<LocalId, CollectionBundle<Batch>>>,
    /// Collects errors raised while the rendered operators run.
    pub err_collector: ErrCollector,
}

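// The collections are torn down through the dataflow graph itself
// (`drop(self.df)`), so a manual `Drop` impl is needed rather than relying on
// the default field-by-field drop.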
impl Drop for Context<'_, '_> {
    fn drop(&mut self) {
        for bundle in std::mem::take(&mut self.input_collection)
            .into_values()
            .chain(
                std::mem::take(&mut self.local_scope)
                    .into_iter()
                    .flat_map(|v| v.into_iter())
                    .map(|(_k, v)| v),
            )
        {
            bundle.collection.into_inner().drop(self.df);
            drop(bundle.arranged);
        }

        for bundle in std::mem::take(&mut self.input_collection_batch)
            .into_values()
            .chain(
                std::mem::take(&mut self.local_scope_batch)
                    .into_iter()
                    .flat_map(|v| v.into_iter())
                    .map(|(_k, v)| v),
            )
        {
            bundle.collection.into_inner().drop(self.df);
            drop(bundle.arranged);
        }
    }
}

impl Context<'_, '_> {
    /// Bind a collection to a global id.
    pub fn insert_global(&mut self, id: GlobalId, collection: CollectionBundle) {
        self.input_collection.insert(id, collection);
    }

    /// Bind a collection to a local id in the innermost scope, creating the
    /// scope if none exists yet.
    pub fn insert_local(&mut self, id: LocalId, collection: CollectionBundle) {
        if let Some(last) = self.local_scope.last_mut() {
            last.insert(id, collection);
        } else {
            let first = BTreeMap::from([(id, collection)]);
            self.local_scope.push(first);
        }
    }

    /// Batch-mode counterpart of `insert_global`.
    pub fn insert_global_batch(&mut self, id: GlobalId, collection: CollectionBundle<Batch>) {
        self.input_collection_batch.insert(id, collection);
    }

    /// Batch-mode counterpart of `insert_local`.
    pub fn insert_local_batch(&mut self, id: LocalId, collection: CollectionBundle<Batch>) {
        if let Some(last) = self.local_scope_batch.last_mut() {
            last.insert(id, collection);
        } else {
            let first = BTreeMap::from([(id, collection)]);
            self.local_scope_batch.push(first);
        }
    }
}

impl Context<'_, '_> {
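    /// Render `plan` into a batch-mode dataflow, returning the bundle that
    /// carries its output. `Join` and `Union` are not implemented yet.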
    pub fn render_plan_batch(&mut self, plan: TypedPlan) -> Result<CollectionBundle<Batch>, Error> {
        match plan.plan {
            Plan::Constant { rows } => Ok(self.render_constant_batch(rows, &plan.schema.typ)),
            Plan::Get { id } => self.get_batch_by_id(id),
            Plan::Let { id, value, body } => self.eval_batch_let(id, value, body),
            Plan::Mfp { input, mfp } => self.render_mfp_batch(input, mfp, &plan.schema.typ),
            Plan::Reduce {
                input,
                key_val_plan,
                reduce_plan,
            } => self.render_reduce_batch(input, &key_val_plan, &reduce_plan, &plan.schema.typ),
            Plan::Join { .. } => NotImplementedSnafu {
                reason: "Join is still WIP",
            }
            .fail(),
            Plan::Union { .. } => NotImplementedSnafu {
                reason: "Union is still WIP",
            }
            .fail(),
        }
    }

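    /// Render `plan` into a row-mode dataflow, returning the bundle that
    /// carries its output. `Join` and `Union` are not implemented yet.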
    pub fn render_plan(&mut self, plan: TypedPlan) -> Result<CollectionBundle, Error> {
        match plan.plan {
            Plan::Constant { rows } => Ok(self.render_constant(rows)),
            Plan::Get { id } => self.get_by_id(id),
            Plan::Let { id, value, body } => self.eval_let(id, value, body),
            Plan::Mfp { input, mfp } => self.render_mfp(input, mfp),
            Plan::Reduce {
                input,
                key_val_plan,
                reduce_plan,
            } => self.render_reduce(input, key_val_plan, reduce_plan, plan.schema.typ),
            Plan::Join { .. } => NotImplementedSnafu {
                reason: "Join is still WIP",
            }
            .fail(),
            Plan::Union { .. } => NotImplementedSnafu {
                reason: "Union is still WIP",
            }
            .fail(),
        }
    }

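    /// Render a `Plan::Constant` in batch mode: rows are grouped by timestamp,
    /// the groups whose timestamp is not greater than the current time are
    /// converted into `Batch`es and emitted, and the subgraph reschedules
    /// itself for the earliest timestamp still pending.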
    pub fn render_constant_batch(
        &mut self,
        rows: Vec<DiffRow>,
        output_type: &RelationType,
    ) -> CollectionBundle<Batch> {
        let (send_port, recv_port) = self.df.make_edge::<_, Toff<Batch>>("constant_batch");
        let mut per_time: BTreeMap<repr::Timestamp, Vec<DiffRow>> = Default::default();
        for (key, group) in &rows.into_iter().chunk_by(|(_row, ts, _diff)| *ts) {
            per_time.entry(key).or_default().extend(group);
        }

        let now = self.compute_state.current_time_ref();
        let scheduler = self.compute_state.get_scheduler();
        let scheduler_inner = scheduler.clone();
        let err_collector = self.err_collector.clone();

        let output_type = output_type.clone();

        let subgraph_id =
            self.df
                .add_subgraph_source("ConstantBatch", send_port, move |_ctx, send_port| {
                    // Split off entries strictly greater than `now`, then swap so
                    // `per_time` keeps the future entries and `not_greater_than_now`
                    // holds everything due for emission.
                    let mut after = per_time.split_off(&(*now.borrow() + 1));
                    std::mem::swap(&mut per_time, &mut after);
                    let not_greater_than_now = after;

                    not_greater_than_now.into_iter().for_each(|(_ts, rows)| {
                        err_collector.run(|| {
                            let rows = rows.into_iter().map(|(row, _ts, _diff)| row).collect();
                            let batch = Batch::try_from_rows_with_types(
                                rows,
                                &output_type
                                    .column_types
                                    .iter()
                                    .map(|ty| ty.scalar_type().clone())
                                    .collect_vec(),
                            )?;
                            send_port.give(vec![batch]);
                            Ok(())
                        });
                    });
                    // Wake this subgraph again when the next pending timestamp is due.
                    if let Some(next_run_time) = per_time.keys().next().copied() {
                        scheduler_inner.schedule_at(next_run_time);
                    }
                });
        scheduler.set_cur_subgraph(subgraph_id);

        CollectionBundle::from_collection(Collection::from_port(recv_port))
    }

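    /// Row-mode counterpart of `render_constant_batch`: emits the rows whose
    /// timestamp is not greater than the current time and reschedules itself
    /// for the earliest timestamp still pending. A minimal usage sketch,
    /// mirroring `test_render_constant` below:
    ///
    /// ```ignore
    /// let bundle = ctx.render_constant(vec![(Row::empty(), 1, 1)]);
    /// ```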
    pub fn render_constant(&mut self, rows: Vec<DiffRow>) -> CollectionBundle {
        let (send_port, recv_port) = self.df.make_edge::<_, Toff>("constant");
        let mut per_time: BTreeMap<repr::Timestamp, Vec<DiffRow>> = Default::default();
        for (key, group) in &rows.into_iter().chunk_by(|(_row, ts, _diff)| *ts) {
            per_time.entry(key).or_default().extend(group);
        }

        let now = self.compute_state.current_time_ref();
        let scheduler = self.compute_state.get_scheduler();
        let scheduler_inner = scheduler.clone();

        let subgraph_id =
            self.df
                .add_subgraph_source("Constant", send_port, move |_ctx, send_port| {
                    // Same split-and-swap as in `render_constant_batch`: keep the
                    // future entries, emit everything that is due.
                    let mut after = per_time.split_off(&(*now.borrow() + 1));
                    std::mem::swap(&mut per_time, &mut after);
                    let not_greater_than_now = after;

                    not_greater_than_now.into_iter().for_each(|(_ts, rows)| {
                        send_port.give(rows);
                    });
                    if let Some(next_run_time) = per_time.keys().next().copied() {
                        scheduler_inner.schedule_at(next_run_time);
                    }
                });
        scheduler.set_cur_subgraph(subgraph_id);

        CollectionBundle::from_collection(Collection::from_port(recv_port))
    }

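    /// Look up a batch-mode collection by id, searching local scopes from the
    /// innermost outwards before consulting the global inputs.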
    pub fn get_batch_by_id(&mut self, id: expr::Id) -> Result<CollectionBundle<Batch>, Error> {
        let ret = match id {
            expr::Id::Local(local) => {
                let bundle = self
                    .local_scope_batch
                    .iter()
                    .rev()
                    .find_map(|scope| scope.get(&local))
                    .with_context(|| InvalidQuerySnafu {
                        reason: format!("Local variable {:?} not found", local),
                    })?;
                bundle.clone(self.df)
            }
            expr::Id::Global(id) => {
                let bundle =
                    self.input_collection_batch
                        .get(&id)
                        .with_context(|| InvalidQuerySnafu {
                            reason: format!("Collection {:?} not found", id),
                        })?;
                bundle.clone(self.df)
            }
        };
        Ok(ret)
    }

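    /// Row-mode counterpart of `get_batch_by_id`.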
    pub fn get_by_id(&mut self, id: expr::Id) -> Result<CollectionBundle, Error> {
        let ret = match id {
            expr::Id::Local(local) => {
                let bundle = self
                    .local_scope
                    .iter()
                    .rev()
                    .find_map(|scope| scope.get(&local))
                    .with_context(|| InvalidQuerySnafu {
                        reason: format!("Local variable {:?} not found", local),
                    })?;
                bundle.clone(self.df)
            }
            expr::Id::Global(id) => {
                let bundle = self
                    .input_collection
                    .get(&id)
                    .with_context(|| InvalidQuerySnafu {
                        reason: format!("Collection {:?} not found", id),
                    })?;
                bundle.clone(self.df)
            }
        };
        Ok(ret)
    }

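    /// Evaluate a batch-mode `Plan::Let`: render `value`, bind it to `id` in a
    /// fresh scope, then render `body` with that binding visible. The scope is
    /// left in place so `Context::drop` can tear the bundle down through the
    /// dataflow graph.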
    pub fn eval_batch_let(
        &mut self,
        id: LocalId,
        value: Box<TypedPlan>,
        body: Box<TypedPlan>,
    ) -> Result<CollectionBundle<Batch>, Error> {
        let value = self.render_plan_batch(*value)?;

        self.local_scope_batch.push(Default::default());
        self.insert_local_batch(id, value);
        let ret = self.render_plan_batch(*body)?;
        Ok(ret)
    }

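    /// Row-mode counterpart of `eval_batch_let`.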
    pub fn eval_let(
        &mut self,
        id: LocalId,
        value: Box<TypedPlan>,
        body: Box<TypedPlan>,
    ) -> Result<CollectionBundle, Error> {
        let value = self.render_plan(*value)?;

        self.local_scope.push(Default::default());
        self.insert_local(id, value);
        let ret = self.render_plan(*body)?;
        Ok(ret)
    }
}

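/// Per-tick arguments handed to the closures of rendered subgraphs: the
/// current timestamp, the error collector, the scheduler, and the output port.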
struct SubgraphArg<'a, T = Toff> {
    now: repr::Timestamp,
    err_collector: &'a ErrCollector,
    scheduler: &'a Scheduler,
    send: &'a PortCtx<SEND, T>,
}

#[cfg(test)]
mod test {
    use std::cell::RefCell;
    use std::rc::Rc;

    use dfir_rs::scheduled::graph::Dfir;
    use dfir_rs::scheduled::graph_ext::GraphExt;
    use dfir_rs::scheduled::handoff::VecHandoff;
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::repr::Row;

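    /// Drive the dataflow over `time_range`, checking after each tick that the
    /// output matches `expected` for that timestamp (or is empty if no entry).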
    pub fn run_and_check(
        state: &mut DataflowState,
        df: &mut Dfir,
        time_range: std::ops::Range<i64>,
        expected: BTreeMap<i64, Vec<DiffRow>>,
        output: Rc<RefCell<Vec<DiffRow>>>,
    ) {
        for now in time_range {
            state.set_current_ts(now);
            state.run_available_with_schedule(df);
            if !state.get_err_collector().is_empty() {
                panic!(
                    "Errors occur: {:?}",
                    state.get_err_collector().get_all_blocking()
                )
            }
            if let Some(expected) = expected.get(&now) {
                assert_eq!(*output.borrow(), *expected, "at ts={}", now);
            } else {
                assert_eq!(*output.borrow(), vec![], "at ts={}", now);
            }
            output.borrow_mut().clear();
        }
    }

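    /// Attach a sink to `bundle` and return a shared handle that holds the
    /// rows received in the most recent tick.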
    pub fn get_output_handle(
        ctx: &mut Context,
        mut bundle: CollectionBundle,
    ) -> Rc<RefCell<Vec<DiffRow>>> {
        let collection = bundle.collection;
        let _arranged = bundle.arranged.pop_first().unwrap().1;
        let output = Rc::new(RefCell::new(vec![]));
        let output_inner = output.clone();
        let _subgraph = ctx.df.add_subgraph_sink(
            "test_render_constant",
            collection.into_inner(),
            move |_ctx, recv| {
                let data = recv.take_inner();
                let res = data.into_iter().flat_map(|v| v.into_iter()).collect_vec();
                output_inner.borrow_mut().clear();
                output_inner.borrow_mut().extend(res);
            },
        );
        output
    }

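    /// Build a test `Context` wired to the given dataflow graph and state.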
    pub fn harness_test_ctx<'r, 'h>(
        df: &'r mut Dfir<'h>,
        state: &'r mut DataflowState,
    ) -> Context<'r, 'h> {
        let err_collector = state.get_err_collector();
        Context {
            id: GlobalId::User(0),
            df,
            compute_state: state,
            input_collection: BTreeMap::new(),
            local_scope: Default::default(),
            input_collection_batch: BTreeMap::new(),
            local_scope_batch: Default::default(),
            err_collector,
        }
    }

    #[test]
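    /// Constant rows should be delivered only once the current time reaches
    /// their timestamps: two rows by ts=2, the third after advancing to ts=3.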
    fn test_render_constant() {
        let mut df = Dfir::new();
        let mut state = DataflowState::default();
        let mut ctx = harness_test_ctx(&mut df, &mut state);

        let rows = vec![
            (Row::empty(), 1, 1),
            (Row::empty(), 2, 1),
            (Row::empty(), 3, 1),
        ];
        let collection = ctx.render_constant(rows);
        let collection = collection.collection.clone(ctx.df);
        let cnt = Rc::new(RefCell::new(0));
        let cnt_inner = cnt.clone();
        let res_subgraph_id = ctx.df.add_subgraph_sink(
            "test_render_constant",
            collection.into_inner(),
            move |_ctx, recv| {
                let data = recv.take_inner();
                *cnt_inner.borrow_mut() += data.iter().map(|v| v.len()).sum::<usize>();
            },
        );
        ctx.compute_state.set_current_ts(2);
        ctx.compute_state.run_available_with_schedule(ctx.df);
        assert_eq!(*cnt.borrow(), 2);

        ctx.compute_state.set_current_ts(3);
        ctx.compute_state.run_available_with_schedule(ctx.df);
        ctx.df.schedule_subgraph(res_subgraph_id);
        ctx.df.run_available();

        assert_eq!(*cnt.borrow(), 3);
    }

    #[test]
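    /// With a plain `VecHandoff`, re-running the sink without rescheduling the
    /// source produces no new data: the sum stays at 0 + 1 + ... + 9 = 45.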
    fn example_source_sink() {
        let mut df = Dfir::new();
        let (send_port, recv_port) = df.make_edge::<_, VecHandoff<i32>>("test_handoff");
        df.add_subgraph_source("test_handoff_source", send_port, move |_ctx, send| {
            for i in 0..10 {
                send.give(vec![i]);
            }
        });

        let sum = Rc::new(RefCell::new(0));
        let sum_move = sum.clone();
        let sink = df.add_subgraph_sink("test_handoff_sink", recv_port, move |_ctx, recv| {
            let data = recv.take_inner();
            *sum_move.borrow_mut() += data.iter().sum::<i32>();
        });

        df.run_available();
        assert_eq!(sum.borrow().to_owned(), 45);
        df.schedule_subgraph(sink);
        df.run_available();

        assert_eq!(sum.borrow().to_owned(), 45);
    }

    #[test]
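    /// With a `TeeingHandoff`, rescheduling the source is enough: the sink is
    /// woken automatically and the sum doubles from 45 to 90.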
    fn test_tee_auto_schedule() {
        use dfir_rs::scheduled::handoff::TeeingHandoff as Toff;
        let mut df = Dfir::new();
        let (send_port, recv_port) = df.make_edge::<_, Toff<i32>>("test_handoff");
        let source = df.add_subgraph_source("test_handoff_source", send_port, move |_ctx, send| {
            for i in 0..10 {
                send.give(vec![i]);
            }
        });
        let teed_recv_port = recv_port.tee(&mut df);

        let sum = Rc::new(RefCell::new(0));
        let sum_move = sum.clone();
        let _sink = df.add_subgraph_sink("test_handoff_sink", teed_recv_port, move |_ctx, recv| {
            let data = recv.take_inner();
            *sum_move.borrow_mut() += data.iter().flat_map(|i| i.iter()).sum::<i32>();
        });
        drop(recv_port);

        df.run_available();
        assert_eq!(sum.borrow().to_owned(), 45);

        df.schedule_subgraph(source);
        df.run_available();

        assert_eq!(sum.borrow().to_owned(), 90);
    }
}