flow/adapter/worker.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! The single-threaded flow worker.

use std::collections::{BTreeMap, VecDeque};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use common_telemetry::info;
use dfir_rs::scheduled::graph::Dfir;
use enum_as_inner::EnumAsInner;
use snafu::ensure;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};

use crate::adapter::FlowId;
use crate::compute::{Context, DataflowState, ErrCollector};
use crate::error::{Error, FlowAlreadyExistSnafu, InternalSnafu, UnexpectedSnafu};
use crate::expr::{Batch, GlobalId};
use crate::plan::TypedPlan;
use crate::repr::{self, DiffRow};

pub type SharedBuf = Arc<Mutex<VecDeque<DiffRow>>>;

type ReqId = usize;

/// Create both a worker (`!Send`) and a worker handle (`Send + Sync`).
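///
/// The worker is `!Send`, so it must be created and run on a dedicated thread,
/// while the handle can be moved out to the caller. A minimal sketch of this
/// setup (mirroring the test module below; not compiled as a doctest):
///
/// ```ignore
/// let (tx, rx) = tokio::sync::oneshot::channel();
/// std::thread::spawn(move || {
///     // create both halves on the worker thread, then ship the handle out
///     let (handle, mut worker) = create_worker();
///     tx.send(handle).unwrap();
///     // blocks this thread until a shutdown request arrives
///     worker.run();
/// });
/// let handle = rx.await.unwrap();
/// ```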
pub fn create_worker<'a>() -> (WorkerHandle, Worker<'a>) {
    let (itc_client, itc_server) = create_inter_thread_call();
    let worker_handle = WorkerHandle {
        itc_client,
        shutdown: AtomicBool::new(false),
    };
    let worker = Worker {
        task_states: BTreeMap::new(),
        itc_server: Arc::new(Mutex::new(itc_server)),
    };
    (worker_handle, worker)
}

/// `ActiveDataflowState` is a wrapper around `Dfir` and `DataflowState`
pub(crate) struct ActiveDataflowState<'subgraph> {
    df: Dfir<'subgraph>,
    state: DataflowState,
    err_collector: ErrCollector,
}

impl std::fmt::Debug for ActiveDataflowState<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ActiveDataflowState")
            .field("df", &"<Dfir>")
            .field("state", &self.state)
            .field("err_collector", &self.err_collector)
            .finish()
    }
}

impl Default for ActiveDataflowState<'_> {
    fn default() -> Self {
        ActiveDataflowState {
            df: Dfir::new(),
            state: DataflowState::default(),
            err_collector: ErrCollector::default(),
        }
    }
}

impl<'subgraph> ActiveDataflowState<'subgraph> {
    /// Create a new render context, assigned the given global id.
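    ///
    /// A sketch of how `create_flow` below uses the returned context to wire
    /// sources, plan, and sink together:
    ///
    /// ```ignore
    /// let mut ctx = state.new_ctx(sink_id);
    /// let bundle = ctx.render_source_batch(src_recv)?;
    /// ctx.insert_global_batch(source_id, bundle);
    /// let rendered = ctx.render_plan_batch(plan)?;
    /// ctx.render_unbounded_sink_batch(rendered, sink_sender);
    /// ```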
    pub fn new_ctx<'ctx>(&'ctx mut self, global_id: GlobalId) -> Context<'ctx, 'subgraph>
    where
        'subgraph: 'ctx,
    {
        Context {
            id: global_id,
            df: &mut self.df,
            compute_state: &mut self.state,
            err_collector: self.err_collector.clone(),
            input_collection: Default::default(),
            local_scope: Default::default(),
            input_collection_batch: Default::default(),
            local_scope_batch: Default::default(),
        }
    }

    pub fn set_current_ts(&mut self, ts: repr::Timestamp) {
        self.state.set_current_ts(ts);
    }

    pub fn set_last_exec_time(&mut self, ts: repr::Timestamp) {
        self.state.set_last_exec_time(ts);
    }

    /// Run all available subgraphs.
    ///
    /// Returns true if any subgraph actually executed.
    pub fn run_available(&mut self) -> bool {
        self.state.run_available_with_schedule(&mut self.df)
    }
}

#[derive(Debug)]
pub struct WorkerHandle {
    itc_client: InterThreadCallClient,
    shutdown: AtomicBool,
}

impl WorkerHandle {
    /// Create a flow, returning its id; returns `Ok(None)` if the flow already exists and is not replaced.
    pub async fn create_flow(&self, create_reqs: Request) -> Result<Option<FlowId>, Error> {
        ensure!(
            matches!(create_reqs, Request::Create { .. }),
            InternalSnafu {
                reason: format!(
                    "Flow Node/Worker itc failed, expect Request::Create, found {create_reqs:?}"
                ),
            }
        );

        let ret = self.itc_client.call_with_resp(create_reqs).await?;
        ret.into_create().map_err(|ret| {
            InternalSnafu {
                reason: format!(
                    "Flow Node/Worker itc failed, expect Response::Create, found {ret:?}"
                ),
            }
            .build()
        })?
    }

    /// Remove a flow; returns true if a flow was actually removed.
    pub async fn remove_flow(&self, flow_id: FlowId) -> Result<bool, Error> {
        let req = Request::Remove { flow_id };

        let ret = self.itc_client.call_with_resp(req).await?;

        ret.into_remove().map_err(|ret| {
            InternalSnafu {
                reason: format!("Flow Node/Worker failed, expect Response::Remove, found {ret:?}"),
            }
            .build()
        })
    }

    /// Trigger the worker to run; the worker runs on its own thread, in parallel
    /// with the caller, so this call itself does not block the worker loop.
    ///
    /// Sets the current timestamp to `now` for all dataflows before running them.
    ///
    /// If `blocking` is true, waits until all dataflows have finished computing;
    /// otherwise it just starts the computation and returns immediately.
    ///
    /// The returned error is unrecoverable, and the worker should be shut down/rebooted.
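    ///
    /// A sketch of the blocking mode (names as in the test module below): feed
    /// a source, run until done, then read the result from the sink:
    ///
    /// ```ignore
    /// tx.send(Batch::empty()).unwrap();
    /// handle.run_available(0, true).await.unwrap();
    /// assert_eq!(sink_rx.recv().await.unwrap(), Batch::empty());
    /// ```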
    pub async fn run_available(&self, now: repr::Timestamp, blocking: bool) -> Result<(), Error> {
        common_telemetry::trace!("Running available with blocking={}", blocking);
        if blocking {
            let resp = self
                .itc_client
                .call_with_resp(Request::RunAvail { now, blocking })
                .await?;
            common_telemetry::trace!("Running available with response={:?}", resp);
            Ok(())
        } else {
            self.itc_client
                .call_no_resp(Request::RunAvail { now, blocking })
        }
    }

    pub async fn contains_flow(&self, flow_id: FlowId) -> Result<bool, Error> {
        let req = Request::ContainTask { flow_id };
        let ret = self.itc_client.call_with_resp(req).await?;

        ret.into_contain_task().map_err(|ret| {
            InternalSnafu {
                reason: format!(
                    "Flow Node/Worker itc failed, expect Response::ContainTask, found {ret:?}"
                ),
            }
            .build()
        })
    }

    /// Shut down the worker.
    pub fn shutdown(&self) -> Result<(), Error> {
        if !self.shutdown.fetch_or(true, Ordering::SeqCst) {
            self.itc_client.call_no_resp(Request::Shutdown)
        } else {
            UnexpectedSnafu {
                reason: "Worker already shutdown",
            }
            .fail()
        }
    }

    pub async fn get_state_size(&self) -> Result<BTreeMap<FlowId, usize>, Error> {
        let ret = self
            .itc_client
            .call_with_resp(Request::QueryStateSize)
            .await?;
        ret.into_query_state_size().map_err(|ret| {
            InternalSnafu {
                reason: format!(
                    "Flow Node/Worker itc failed, expect Response::QueryStateSize, found {ret:?}"
                ),
            }
            .build()
        })
    }

    pub async fn get_last_exec_time_map(&self) -> Result<BTreeMap<FlowId, i64>, Error> {
        let ret = self
            .itc_client
            .call_with_resp(Request::QueryLastExecTimeMap)
            .await?;
        ret.into_query_last_exec_time_map().map_err(|ret| {
            InternalSnafu {
                reason: format!(
                    "Flow Node/Worker get_last_exec_time_map failed, expect Response::QueryLastExecTimeMap, found {ret:?}"
                ),
            }
            .build()
        })
    }
}

impl Drop for WorkerHandle {
    fn drop(&mut self) {
        if let Err(ret) = self.shutdown() {
            common_telemetry::error!(
                ret;
                "While dropping WorkerHandle, failed to shut down worker; the worker might be in an inconsistent state."
            );
        } else {
            info!("Flow worker shut down because its WorkerHandle was dropped.")
        }
    }
}

/// The actual worker that does the work and contains the active state
#[derive(Debug)]
pub struct Worker<'subgraph> {
    /// Task states
    pub(crate) task_states: BTreeMap<FlowId, ActiveDataflowState<'subgraph>>,
    itc_server: Arc<Mutex<InterThreadCallServer>>,
}

impl<'s> Worker<'s> {
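    /// Create a flow with the given id from a typed plan, wiring the given
    /// source receivers and the sink sender into the rendered dataflow.
    ///
    /// If a flow with the same id already exists: `or_replace = true` replaces
    /// it; otherwise `create_if_not_exists = true` keeps the old flow and
    /// returns `Ok(None)`; if both are false, a `FlowAlreadyExist` error is
    /// returned.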
    #[allow(clippy::too_many_arguments)]
    pub fn create_flow(
        &mut self,
        flow_id: FlowId,
        plan: TypedPlan,
        sink_id: GlobalId,
        sink_sender: mpsc::UnboundedSender<Batch>,
        source_ids: &[GlobalId],
        src_recvs: Vec<broadcast::Receiver<Batch>>,
        // TODO(discord9): set expire duration for all arrangement and compare to sys timestamp instead
        expire_after: Option<repr::Duration>,
        or_replace: bool,
        create_if_not_exists: bool,
        err_collector: ErrCollector,
    ) -> Result<Option<FlowId>, Error> {
        let already_exist = self.task_states.contains_key(&flow_id);
        match (create_if_not_exists, or_replace, already_exist) {
            // if replace, ignore that old flow exists
            (_, true, true) => {
                info!("Replacing flow with id={}", flow_id);
            }
            (false, false, true) => FlowAlreadyExistSnafu { id: flow_id }.fail()?,
            // already exists, and not replace, return None
            (true, false, true) => {
                info!("Flow with id={} already exists, do nothing", flow_id);
                return Ok(None);
            }
            // continue as normal
            (_, _, false) => (),
        }

        let mut cur_task_state = ActiveDataflowState::<'s> {
            err_collector,
            ..Default::default()
        };
        cur_task_state.state.set_expire_after(expire_after);

        {
            let mut ctx = cur_task_state.new_ctx(sink_id);
            for (source_id, src_recv) in source_ids.iter().zip(src_recvs) {
                let bundle = ctx.render_source_batch(src_recv)?;
                ctx.insert_global_batch(*source_id, bundle);
            }

            let rendered = ctx.render_plan_batch(plan)?;
            ctx.render_unbounded_sink_batch(rendered, sink_sender);
        }
        self.task_states.insert(flow_id, cur_task_state);
        Ok(Some(flow_id))
    }

    /// Remove a flow; returns true if a flow was removed.
    pub fn remove_flow(&mut self, flow_id: FlowId) -> bool {
        self.task_states.remove(&flow_id).is_some()
    }

    /// Run the worker, blocking, until a shutdown signal is received.
    pub fn run(&mut self) {
        loop {
            let (req, ret_tx) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() {
                ret
            } else {
                common_telemetry::error!(
                    "Worker's itc server has been closed unexpectedly, shutting down worker now."
                );
                break;
            };

            let ret = self.handle_req(req);
            match (ret, ret_tx) {
                (Ok(Some(resp)), Some(ret_tx)) => {
                    if let Err(err) = ret_tx.send(resp) {
                        common_telemetry::error!(
                            err;
                            "Result receiver is dropped, can't send result"
                        );
                    }
                }
                (Ok(None), None) => continue,
                (Ok(Some(resp)), None) => {
                    common_telemetry::error!(
                        "Expect no result for current request, but found {resp:?}"
                    )
                }
                (Ok(None), Some(_)) => {
                    common_telemetry::error!("Expect result for current request, but found nothing")
                }
                (Err(()), _) => {
                    break;
                }
            }
        }
    }

    /// Run with the tick acquired from the tick manager (usually the system time).
    /// TODO(discord9): better tick management
    pub fn run_tick(&mut self, now: repr::Timestamp) {
        for (_flow_id, task_state) in self.task_states.iter_mut() {
            task_state.set_current_ts(now);
            task_state.set_last_exec_time(now);
            task_state.run_available();
        }
    }

    /// Handle a request, returning a response if one is expected.
    ///
    /// Returns `Err(())` if a shutdown request is received.
    fn handle_req(&mut self, req: Request) -> Result<Option<Response>, ()> {
        let ret = match req {
            Request::Create {
                flow_id,
                plan,
                sink_id,
                sink_sender,
                source_ids,
                src_recvs,
                expire_after,
                or_replace,
                create_if_not_exists,
                err_collector,
            } => {
                let task_create_result = self.create_flow(
                    flow_id,
                    plan,
                    sink_id,
                    sink_sender,
                    &source_ids,
                    src_recvs,
                    expire_after,
                    or_replace,
                    create_if_not_exists,
                    err_collector,
                );
                Some(Response::Create {
                    result: task_create_result,
                })
            }
            Request::Remove { flow_id } => {
                let ret = self.remove_flow(flow_id);
                Some(Response::Remove { result: ret })
            }
            Request::RunAvail { now, blocking } => {
                self.run_tick(now);
                if blocking {
                    Some(Response::RunAvail)
                } else {
                    None
                }
            }
            Request::ContainTask { flow_id } => {
                let ret = self.task_states.contains_key(&flow_id);
                Some(Response::ContainTask { result: ret })
            }
            Request::Shutdown => return Err(()),
            Request::QueryStateSize => {
                let mut ret = BTreeMap::new();
                for (flow_id, task_state) in self.task_states.iter() {
                    ret.insert(*flow_id, task_state.state.get_state_size());
                }
                Some(Response::QueryStateSize { result: ret })
            }
            Request::QueryLastExecTimeMap => {
                let mut ret = BTreeMap::new();
                for (flow_id, task_state) in self.task_states.iter() {
                    if let Some(last_exec_time) = task_state.state.last_exec_time() {
                        ret.insert(*flow_id, last_exec_time);
                    }
                }
                Some(Response::QueryLastExecTimeMap { result: ret })
            }
        };
        Ok(ret)
    }
}

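/// Requests sent from a `WorkerHandle` to its `Worker` over the inter-thread
/// call channel.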
#[derive(Debug, EnumAsInner)]
pub enum Request {
    Create {
        flow_id: FlowId,
        plan: TypedPlan,
        sink_id: GlobalId,
        sink_sender: mpsc::UnboundedSender<Batch>,
        source_ids: Vec<GlobalId>,
        src_recvs: Vec<broadcast::Receiver<Batch>>,
        expire_after: Option<repr::Duration>,
        or_replace: bool,
        create_if_not_exists: bool,
        err_collector: ErrCollector,
    },
    Remove {
        flow_id: FlowId,
    },
    /// Trigger the worker to run; useful after the input buffer is full
    RunAvail {
        now: repr::Timestamp,
        blocking: bool,
    },
    ContainTask {
        flow_id: FlowId,
    },
    Shutdown,
    QueryStateSize,
    QueryLastExecTimeMap,
}

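/// Responses sent back from the `Worker`, mirroring the `Request` variants
/// that expect a reply.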
#[derive(Debug, EnumAsInner)]
enum Response {
    Create {
        result: Result<Option<FlowId>, Error>,
        // TODO(discord9): add flow err_collector
    },
    Remove {
        result: bool,
    },
    ContainTask {
        result: bool,
    },
    RunAvail,
    QueryStateSize {
        /// each flow task's state size
        result: BTreeMap<FlowId, usize>,
    },
    QueryLastExecTimeMap {
        /// each flow task's last execution time
        result: BTreeMap<FlowId, i64>,
    },
}

fn create_inter_thread_call() -> (InterThreadCallClient, InterThreadCallServer) {
    let (arg_send, arg_recv) = mpsc::unbounded_channel();
    let client = InterThreadCallClient {
        arg_sender: arg_send,
    };
    let server = InterThreadCallServer { arg_recv };
    (client, server)
}

#[derive(Debug)]
struct InterThreadCallClient {
    arg_sender: mpsc::UnboundedSender<(Request, Option<oneshot::Sender<Response>>)>,
}

impl InterThreadCallClient {
    /// Call without waiting for a response.
    fn call_no_resp(&self, req: Request) -> Result<(), Error> {
        self.arg_sender.send((req, None)).map_err(from_send_error)
    }

    /// Call and wait for a response.
    async fn call_with_resp(&self, req: Request) -> Result<Response, Error> {
        let (tx, rx) = oneshot::channel();
        self.arg_sender
            .send((req, Some(tx)))
            .map_err(from_send_error)?;
        rx.await.map_err(|_| {
            InternalSnafu {
                reason: "Sender is dropped",
            }
            .build()
        })
    }
}

#[derive(Debug)]
struct InterThreadCallServer {
    pub arg_recv: mpsc::UnboundedReceiver<(Request, Option<oneshot::Sender<Response>>)>,
}

impl InterThreadCallServer {
    pub async fn recv(&mut self) -> Option<(Request, Option<oneshot::Sender<Response>>)> {
        self.arg_recv.recv().await
    }

    pub fn blocking_recv(&mut self) -> Option<(Request, Option<oneshot::Sender<Response>>)> {
        self.arg_recv.blocking_recv()
    }
}

fn from_send_error<T>(err: mpsc::error::SendError<T>) -> Error {
    InternalSnafu {
        // this `err` will simply display `channel closed`
        reason: format!(
            "Worker's receiver channel has been closed unexpectedly: {}",
            err
        ),
    }
    .build()
}

#[cfg(test)]
mod test {
    use tokio::sync::oneshot;

    use super::*;
    use crate::expr::Id;
    use crate::plan::Plan;
    use crate::repr::RelationType;

    #[test]
    fn drop_handle() {
        let (tx, rx) = oneshot::channel();
        let worker_thread_handle = std::thread::spawn(move || {
            let (handle, mut worker) = create_worker();
            tx.send(handle).unwrap();
            worker.run();
        });
        let handle = rx.blocking_recv().unwrap();
        drop(handle);
        worker_thread_handle.join().unwrap();
    }

    #[tokio::test]
    pub async fn test_simple_get_with_worker_and_handle() {
        let (tx, rx) = oneshot::channel();
        let worker_thread_handle = std::thread::spawn(move || {
            let (handle, mut worker) = create_worker();
            tx.send(handle).unwrap();
            worker.run();
        });
        let handle = rx.await.unwrap();
        let src_ids = vec![GlobalId::User(1)];
        let (tx, rx) = broadcast::channel::<Batch>(1024);
        let (sink_tx, mut sink_rx) = mpsc::unbounded_channel::<Batch>();
        let (flow_id, plan) = (
            1,
            TypedPlan {
                plan: Plan::Get {
                    id: Id::Global(GlobalId::User(1)),
                },
                schema: RelationType::new(vec![]).into_unnamed(),
            },
        );
        let create_reqs = Request::Create {
            flow_id,
            plan,
            sink_id: GlobalId::User(1),
            sink_sender: sink_tx,
            source_ids: src_ids,
            src_recvs: vec![rx],
            expire_after: None,
            or_replace: false,
            create_if_not_exists: true,
            err_collector: ErrCollector::default(),
        };
        assert_eq!(
            handle.create_flow(create_reqs).await.unwrap(),
            Some(flow_id)
        );
        tx.send(Batch::empty()).unwrap();
        handle.run_available(0, true).await.unwrap();
        assert_eq!(handle.get_state_size().await.unwrap().len(), 1);
        assert_eq!(sink_rx.recv().await.unwrap(), Batch::empty());
        drop(handle);
        worker_thread_handle.join().unwrap();
    }
}