1use std::any::Any;
16use std::fmt;
17use std::fmt::Display;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use common_event_recorder::{Event, Eventable};
23use serde::{Deserialize, Serialize};
24use smallvec::{smallvec, SmallVec};
25use snafu::{ResultExt, Snafu};
26use tokio::sync::watch::Receiver;
27use uuid::Uuid;
28
29use crate::error::{self, Error, Result};
30use crate::local::DynamicKeyLockGuard;
31use crate::watcher::Watcher;
32
33pub type Output = Arc<dyn Any + Send + Sync>;
34
35#[derive(Debug)]
37pub enum Status {
38 Executing {
40 persist: bool,
42 clean_poisons: bool,
44 },
45 Suspended {
47 subprocedures: Vec<ProcedureWithId>,
48 persist: bool,
50 },
51 Poisoned {
53 keys: PoisonKeys,
55 error: Error,
57 },
58 Done { output: Option<Output> },
60}
61
62impl Status {
63 pub fn suspended(subprocedures: Vec<ProcedureWithId>, persist: bool) -> Status {
65 Status::Suspended {
66 subprocedures,
67 persist,
68 }
69 }
70
71 pub fn poisoned(keys: impl IntoIterator<Item = PoisonKey>, error: Error) -> Status {
73 Status::Poisoned {
74 keys: PoisonKeys::new(keys),
75 error,
76 }
77 }
78
79 pub fn executing(persist: bool) -> Status {
81 Status::Executing {
82 persist,
83 clean_poisons: false,
84 }
85 }
86
87 pub fn executing_with_clean_poisons(persist: bool) -> Status {
89 Status::Executing {
90 persist,
91 clean_poisons: true,
92 }
93 }
94
95 pub fn done() -> Status {
97 Status::Done { output: None }
98 }
99
100 #[cfg(any(test, feature = "testing"))]
101 pub fn downcast_output_ref<T: 'static>(&self) -> Option<&T> {
106 if let Status::Done { output } = self {
107 output
108 .as_ref()
109 .expect("Try to downcast the output of Status::Done, but the output is None")
110 .downcast_ref()
111 } else {
112 panic!("Expected the Status::Done, but got: {:?}", self)
113 }
114 }
115
116 pub fn done_with_output<T: Any + Send + Sync>(output: T) -> Status {
118 Status::Done {
119 output: Some(Arc::new(output)),
120 }
121 }
122 pub fn is_done(&self) -> bool {
124 matches!(self, Status::Done { .. })
125 }
126
127 pub fn need_persist(&self) -> bool {
129 match self {
130 Status::Executing { persist, .. } | Status::Suspended { persist, .. } => *persist,
133 Status::Done { .. } | Status::Poisoned { .. } => false,
134 }
135 }
136
137 pub fn need_clean_poisons(&self) -> bool {
139 match self {
140 Status::Executing { clean_poisons, .. } => *clean_poisons,
141 Status::Done { .. } => true,
142 _ => false,
143 }
144 }
145}
146
#[async_trait]
/// Gives a running procedure access to procedure states, poison keys and dynamic locks.
pub trait ContextProvider: Send + Sync {
    /// Queries the state of the procedure with `procedure_id`.
    ///
    /// NOTE(review): `Ok(None)` presumably means the procedure is unknown — confirm with implementors.
    async fn procedure_state(&self, procedure_id: ProcedureId) -> Result<Option<ProcedureState>>;

    /// Returns a receiver that observes state updates of the procedure with `procedure_id`,
    /// or `None` if no receiver is available.
    async fn procedure_state_receiver(
        &self,
        procedure_id: ProcedureId,
    ) -> Result<Option<Receiver<ProcedureState>>>;

    /// Tries to put the poison `key` on behalf of the procedure with `procedure_id`.
    ///
    /// NOTE(review): the failure conditions (e.g. the key is already poisoned by another
    /// procedure) are defined by the implementor — confirm.
    async fn try_put_poison(&self, key: &PoisonKey, procedure_id: ProcedureId) -> Result<()>;

    /// Acquires the dynamic lock for `key` and returns its guard.
    ///
    /// NOTE(review): presumably the lock is released when the guard is dropped — confirm with
    /// `DynamicKeyLockGuard`.
    async fn acquire_lock(&self, key: &StringKey) -> DynamicKeyLockGuard;
}

/// Shared, thread-safe reference to a [`ContextProvider`].
pub type ContextProviderRef = Arc<dyn ContextProvider>;
170
#[derive(Clone)]
/// Context handed to a procedure when it is executed or rolled back.
pub struct Context {
    /// Id of the procedure this context belongs to.
    pub procedure_id: ProcedureId,
    /// Provider of procedure states, poison keys and locks.
    pub provider: ContextProviderRef,
}
179
#[async_trait]
/// A unit of work that the procedure manager can execute, dump and recover.
pub trait Procedure: Send {
    /// Returns the type name of the procedure.
    ///
    /// NOTE(review): presumably matched against the name given to
    /// `ProcedureManager::register_loader` when reloading — confirm.
    fn type_name(&self) -> &str;

    /// Executes the procedure and returns its next [`Status`].
    async fn execute(&mut self, ctx: &Context) -> Result<Status>;

    /// Rolls back the procedure.
    ///
    /// The default implementation always fails with the "rollback not supported" error.
    async fn rollback(&mut self, _: &Context) -> Result<()> {
        error::RollbackNotSupportedSnafu {}.fail()
    }

    /// Returns whether [`Procedure::rollback`] is supported. Defaults to `false`.
    fn rollback_supported(&self) -> bool {
        false
    }

    /// Dumps the internal state of the procedure into a string.
    fn dump(&self) -> Result<String>;

    /// Recovery hook; the default implementation does nothing and returns `Ok(())`.
    ///
    /// NOTE(review): presumably invoked after a procedure is reloaded from its dump — confirm.
    fn recover(&mut self) -> Result<()> {
        Ok(())
    }

    /// Returns the [`LockKey`] of the procedure (presumably the keys locked while it runs).
    fn lock_key(&self) -> LockKey;

    /// Returns the poison keys of the procedure. Empty by default.
    fn poison_keys(&self) -> PoisonKeys {
        PoisonKeys::default()
    }

    /// Returns optional user metadata of the procedure. `None` by default.
    fn user_metadata(&self) -> Option<UserMetadata> {
        None
    }
}
224
225#[derive(Clone, Debug)]
227pub struct UserMetadata {
228 event_object: Arc<dyn Eventable>,
229}
230
231impl UserMetadata {
232 pub fn new(event_object: Arc<dyn Eventable>) -> Self {
234 Self { event_object }
235 }
236
237 pub fn to_event(&self) -> Option<Box<dyn Event>> {
239 self.event_object.to_event()
240 }
241}
242
243#[async_trait]
244impl<T: Procedure + ?Sized> Procedure for Box<T> {
245 fn type_name(&self) -> &str {
246 (**self).type_name()
247 }
248
249 async fn execute(&mut self, ctx: &Context) -> Result<Status> {
250 (**self).execute(ctx).await
251 }
252
253 async fn rollback(&mut self, ctx: &Context) -> Result<()> {
254 (**self).rollback(ctx).await
255 }
256
257 fn rollback_supported(&self) -> bool {
258 (**self).rollback_supported()
259 }
260
261 fn dump(&self) -> Result<String> {
262 (**self).dump()
263 }
264
265 fn lock_key(&self) -> LockKey {
266 (**self).lock_key()
267 }
268
269 fn poison_keys(&self) -> PoisonKeys {
270 (**self).poison_keys()
271 }
272}
273
/// Identifier of a resource that a procedure can be poisoned on.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PoisonKey(String);

impl Display for PoisonKey {
    /// Writes the raw key string (formatter width/fill options are not applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let PoisonKey(key) = self;
        f.write_str(key)
    }
}

impl PoisonKey {
    /// Creates a poison key from anything convertible into a `String`.
    pub fn new(key: impl Into<String>) -> Self {
        let key: String = key.into();
        PoisonKey(key)
    }
}
289
290#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
294pub struct PoisonKeys(SmallVec<[PoisonKey; 2]>);
295
296impl PoisonKeys {
297 pub fn single(key: impl Into<String>) -> Self {
299 Self(smallvec![PoisonKey::new(key)])
300 }
301
302 pub fn new(keys: impl IntoIterator<Item = PoisonKey>) -> Self {
304 Self(keys.into_iter().collect())
305 }
306
307 pub fn contains(&self, key: &PoisonKey) -> bool {
309 self.0.contains(key)
310 }
311
312 pub fn iter(&self) -> impl Iterator<Item = &PoisonKey> {
314 self.0.iter()
315 }
316}
317
/// A lock key string together with its lock mode.
///
/// The derived `Ord` compares the variant first (in declaration order), so every `Share` key
/// sorts before every `Exclusive` key, then keys compare by string.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StringKey {
    /// Key requested in shared mode.
    Share(String),
    /// Key requested in exclusive mode.
    Exclusive(String),
}

/// The set of keys a procedure reports via `Procedure::lock_key`.
///
/// Up to two keys are stored inline. `LockKey::new` sorts and deduplicates the keys.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct LockKey(SmallVec<[StringKey; 2]>);
331
332impl StringKey {
333 pub fn into_string(self) -> String {
334 match self {
335 StringKey::Share(s) => s,
336 StringKey::Exclusive(s) => s,
337 }
338 }
339
340 pub fn as_string(&self) -> &String {
341 match self {
342 StringKey::Share(s) => s,
343 StringKey::Exclusive(s) => s,
344 }
345 }
346}
347
348impl LockKey {
349 pub fn single(key: impl Into<StringKey>) -> LockKey {
351 LockKey(smallvec![key.into()])
352 }
353
354 pub fn single_exclusive(key: impl Into<String>) -> LockKey {
356 LockKey(smallvec![StringKey::Exclusive(key.into())])
357 }
358
359 pub fn new(iter: impl IntoIterator<Item = StringKey>) -> LockKey {
361 let mut vec: SmallVec<_> = iter.into_iter().collect();
362 vec.sort();
363 vec.dedup();
365 LockKey(vec)
366 }
367
368 pub fn new_exclusive(iter: impl IntoIterator<Item = String>) -> LockKey {
370 Self::new(iter.into_iter().map(StringKey::Exclusive))
371 }
372
373 pub fn keys_to_lock(&self) -> impl Iterator<Item = &StringKey> {
375 self.0.iter()
376 }
377
378 pub fn get_keys(&self) -> Vec<String> {
380 self.0.iter().map(|key| format!("{:?}", key)).collect()
381 }
382}
383
384pub type BoxedProcedure = Box<dyn Procedure>;
386
387pub struct ProcedureWithId {
389 pub id: ProcedureId,
391 pub procedure: BoxedProcedure,
392}
393
394impl ProcedureWithId {
395 pub fn with_random_id(procedure: BoxedProcedure) -> ProcedureWithId {
398 ProcedureWithId {
399 id: ProcedureId::random(),
400 procedure,
401 }
402 }
403}
404
405impl fmt::Debug for ProcedureWithId {
406 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407 write!(f, "{}-{}", self.procedure.type_name(), self.id)
408 }
409}
410
411#[derive(Debug, Snafu)]
412pub struct ParseIdError {
413 source: uuid::Error,
414}
415
416#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
418pub struct ProcedureId(Uuid);
419
420impl ProcedureId {
421 pub fn random() -> ProcedureId {
423 ProcedureId(Uuid::new_v4())
424 }
425
426 pub fn parse_str(input: &str) -> std::result::Result<ProcedureId, ParseIdError> {
428 Uuid::parse_str(input)
429 .map(ProcedureId)
430 .context(ParseIdSnafu)
431 }
432}
433
434impl fmt::Display for ProcedureId {
435 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436 write!(f, "{}", self.0)
437 }
438}
439
440impl FromStr for ProcedureId {
441 type Err = ParseIdError;
442
443 fn from_str(s: &str) -> std::result::Result<ProcedureId, ParseIdError> {
444 ProcedureId::parse_str(s)
445 }
446}
447
448pub type BoxedProcedureLoader = Box<dyn Fn(&str) -> Result<BoxedProcedure> + Send>;
450
451#[derive(Debug, Default, Clone)]
453pub enum ProcedureState {
454 #[default]
456 Running,
457 Done { output: Option<Output> },
459 Retrying { error: Arc<Error> },
461 PrepareRollback { error: Arc<Error> },
463 RollingBack { error: Arc<Error> },
465 Failed { error: Arc<Error> },
467 Poisoned { keys: PoisonKeys, error: Arc<Error> },
469}
470
471impl ProcedureState {
472 pub fn failed(error: Arc<Error>) -> ProcedureState {
474 ProcedureState::Failed { error }
475 }
476
477 pub fn prepare_rollback(error: Arc<Error>) -> ProcedureState {
479 ProcedureState::PrepareRollback { error }
480 }
481
482 pub fn rolling_back(error: Arc<Error>) -> ProcedureState {
484 ProcedureState::RollingBack { error }
485 }
486
487 pub fn retrying(error: Arc<Error>) -> ProcedureState {
489 ProcedureState::Retrying { error }
490 }
491
492 pub fn poisoned(keys: PoisonKeys, error: Arc<Error>) -> ProcedureState {
494 ProcedureState::Poisoned { keys, error }
495 }
496
497 pub fn is_running(&self) -> bool {
499 matches!(self, ProcedureState::Running)
500 }
501
502 pub fn is_done(&self) -> bool {
504 matches!(self, ProcedureState::Done { .. })
505 }
506
507 pub fn is_poisoned(&self) -> bool {
509 matches!(self, ProcedureState::Poisoned { .. })
510 }
511
512 pub fn is_failed(&self) -> bool {
514 matches!(self, ProcedureState::Failed { .. })
515 }
516
517 pub fn is_retrying(&self) -> bool {
519 matches!(self, ProcedureState::Retrying { .. })
520 }
521
522 pub fn is_rolling_back(&self) -> bool {
524 matches!(self, ProcedureState::RollingBack { .. })
525 }
526
527 pub fn is_prepare_rollback(&self) -> bool {
529 matches!(self, ProcedureState::PrepareRollback { .. })
530 }
531
532 pub fn error(&self) -> Option<&Arc<Error>> {
534 match self {
535 ProcedureState::Failed { error } => Some(error),
536 ProcedureState::Retrying { error } => Some(error),
537 ProcedureState::RollingBack { error } => Some(error),
538 ProcedureState::Poisoned { error, .. } => Some(error),
539 _ => None,
540 }
541 }
542
543 pub fn as_str_name(&self) -> &str {
545 match self {
546 ProcedureState::Running => "Running",
547 ProcedureState::Done { .. } => "Done",
548 ProcedureState::Retrying { .. } => "Retrying",
549 ProcedureState::Failed { .. } => "Failed",
550 ProcedureState::PrepareRollback { .. } => "PrepareRollback",
551 ProcedureState::RollingBack { .. } => "RollingBack",
552 ProcedureState::Poisoned { .. } => "Poisoned",
553 }
554 }
555}
556
#[derive(Debug, Clone)]
/// State a procedure is (re)started in.
///
/// NOTE(review): presumably chosen when a procedure is submitted or recovered — confirm with the
/// procedure manager implementation.
pub enum InitProcedureState {
    /// Start in the running state.
    Running,
    /// Start in the rolling-back state.
    RollingBack,
}
563
#[async_trait]
/// Manages procedures: loader registration, lifecycle, submission and state queries.
pub trait ProcedureManager: Send + Sync + 'static {
    /// Registers `loader` under `name`.
    ///
    /// NOTE(review): presumably `name` matches `Procedure::type_name` of the procedures it loads — confirm.
    fn register_loader(&self, name: &str, loader: BoxedProcedureLoader) -> Result<()>;

    /// Starts the manager.
    async fn start(&self) -> Result<()>;

    /// Stops the manager.
    async fn stop(&self) -> Result<()>;

    /// Submits `procedure` and returns a [`Watcher`] (presumably for observing its state).
    async fn submit(&self, procedure: ProcedureWithId) -> Result<Watcher>;

    /// Queries the state of the procedure with `procedure_id`.
    ///
    /// NOTE(review): `Ok(None)` presumably means the procedure is unknown — confirm.
    async fn procedure_state(&self, procedure_id: ProcedureId) -> Result<Option<ProcedureState>>;

    /// Returns a [`Watcher`] for the procedure with `procedure_id`, if one is available.
    fn procedure_watcher(&self, procedure_id: ProcedureId) -> Option<Watcher>;

    /// Lists information about the procedures known to the manager.
    async fn list_procedures(&self) -> Result<Vec<ProcedureInfo>>;
}

/// Shared, thread-safe reference to a [`ProcedureManager`].
pub type ProcedureManagerRef = Arc<dyn ProcedureManager>;
600
/// Information about a procedure, as returned by [`ProcedureManager::list_procedures`].
#[derive(Debug, Clone)]
pub struct ProcedureInfo {
    /// Id of the procedure.
    pub id: ProcedureId,
    /// Type name of the procedure.
    pub type_name: String,
    /// Start time in milliseconds (NOTE(review): assumed to be a Unix epoch timestamp — confirm).
    pub start_time_ms: i64,
    /// End time in milliseconds (NOTE(review): assumed to be a Unix epoch timestamp — confirm).
    pub end_time_ms: i64,
    /// State of the procedure.
    pub state: ProcedureState,
    /// Lock keys of the procedure, rendered as strings.
    pub lock_keys: Vec<String>,
}
616
#[cfg(test)]
mod tests {
    use common_error::mock::MockError;
    use common_error::status_code::StatusCode;

    use super::*;

    /// Builds an arbitrary external error for tests.
    fn mock_error() -> Error {
        Error::external(MockError::new(StatusCode::Unexpected))
    }

    #[test]
    fn test_status() {
        let status = Status::executing(false);
        assert!(!status.need_persist());
        assert!(!status.need_clean_poisons());

        let status = Status::executing(true);
        assert!(status.need_persist());
        assert!(!status.need_clean_poisons());

        let status = Status::executing_with_clean_poisons(false);
        assert!(status.need_clean_poisons());
        assert!(!status.need_persist());

        let status = Status::executing_with_clean_poisons(true);
        assert!(status.need_clean_poisons());
        assert!(status.need_persist());

        let status = Status::Suspended {
            subprocedures: Vec::new(),
            persist: false,
        };
        assert!(!status.need_persist());
        assert!(!status.need_clean_poisons());

        let status = Status::Suspended {
            subprocedures: Vec::new(),
            persist: true,
        };
        assert!(status.need_persist());

        let status = Status::done();
        assert!(status.is_done());
        assert!(!status.need_persist());
        assert!(status.need_clean_poisons());
    }

    #[test]
    fn test_poisoned_status() {
        let status = Status::poisoned([PoisonKey::new("table/1")], mock_error());
        assert!(!status.is_done());
        assert!(!status.need_persist());
        assert!(!status.need_clean_poisons());
        match status {
            Status::Poisoned { keys, .. } => {
                assert!(keys.contains(&PoisonKey::new("table/1")));
            }
            other => panic!("Expected Status::Poisoned, but got: {other:?}"),
        }
    }

    #[test]
    fn test_status_downcast_output() {
        let status = Status::done_with_output(42u32);
        assert!(status.is_done());
        assert_eq!(Some(&42u32), status.downcast_output_ref::<u32>());
        assert!(status.downcast_output_ref::<String>().is_none());
    }

    #[test]
    fn test_poison_keys() {
        let keys = PoisonKeys::new([PoisonKey::new("a"), PoisonKey::new("b")]);
        assert!(keys.contains(&PoisonKey::new("a")));
        assert!(!keys.contains(&PoisonKey::new("c")));
        assert_eq!(
            vec!["a", "b"],
            keys.iter().map(|k| k.to_string()).collect::<Vec<_>>()
        );
        assert!(PoisonKeys::single("x").contains(&PoisonKey::new("x")));
        assert!(!PoisonKeys::default().contains(&PoisonKey::new("x")));
    }

    #[test]
    fn test_lock_key() {
        let entity = "catalog.schema.my_table";
        let key = LockKey::single_exclusive(entity);
        assert_eq!(
            vec![&StringKey::Exclusive(entity.to_string())],
            key.keys_to_lock().collect::<Vec<_>>()
        );

        let key = LockKey::new_exclusive([
            "b".to_string(),
            "c".to_string(),
            "a".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(
            vec![
                &StringKey::Exclusive("a".to_string()),
                &StringKey::Exclusive("b".to_string()),
                &StringKey::Exclusive("c".to_string())
            ],
            key.keys_to_lock().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_lock_key_mixed_modes() {
        // Derived `Ord` places `Share` before `Exclusive`; duplicates are removed.
        let key = LockKey::new([
            StringKey::Exclusive("a".to_string()),
            StringKey::Share("a".to_string()),
            StringKey::Share("a".to_string()),
        ]);
        assert_eq!(
            vec![
                &StringKey::Share("a".to_string()),
                &StringKey::Exclusive("a".to_string())
            ],
            key.keys_to_lock().collect::<Vec<_>>()
        );

        assert_eq!(
            vec!["Exclusive(\"t\")".to_string()],
            LockKey::single_exclusive("t").get_keys()
        );
        assert_eq!("a", StringKey::Share("a".to_string()).as_string().as_str());
        assert_eq!("b", StringKey::Exclusive("b".to_string()).into_string());
    }

    #[test]
    fn test_procedure_id() {
        let id = ProcedureId::random();
        let uuid_str = id.to_string();
        assert_eq!(id.0.to_string(), uuid_str);

        let parsed = ProcedureId::parse_str(&uuid_str).unwrap();
        assert_eq!(id, parsed);
        let parsed = uuid_str.parse().unwrap();
        assert_eq!(id, parsed);
    }

    #[test]
    fn test_procedure_id_serialization() {
        let id = ProcedureId::random();
        let json = serde_json::to_string(&id).unwrap();
        assert_eq!(format!("\"{id}\""), json);

        let parsed = serde_json::from_str(&json).unwrap();
        assert_eq!(id, parsed);
    }

    #[test]
    fn test_procedure_state() {
        assert!(ProcedureState::Running.is_running());
        assert!(ProcedureState::Running.error().is_none());
        assert!(ProcedureState::Done { output: None }.is_done());
        assert!(ProcedureState::default().is_running());

        let state = ProcedureState::failed(Arc::new(mock_error()));
        assert!(state.is_failed());
        let _ = state.error().unwrap();
    }

    #[test]
    fn test_procedure_state_predicates_and_names() {
        let error = Arc::new(mock_error());
        let cases = [
            (ProcedureState::Running, "Running"),
            (ProcedureState::Done { output: None }, "Done"),
            (ProcedureState::retrying(error.clone()), "Retrying"),
            (ProcedureState::failed(error.clone()), "Failed"),
            (
                ProcedureState::prepare_rollback(error.clone()),
                "PrepareRollback",
            ),
            (ProcedureState::rolling_back(error.clone()), "RollingBack"),
            (
                ProcedureState::poisoned(PoisonKeys::single("k"), error.clone()),
                "Poisoned",
            ),
        ];
        for (state, name) in &cases {
            assert_eq!(*name, state.as_str_name());
        }

        assert!(ProcedureState::retrying(error.clone()).is_retrying());
        assert!(ProcedureState::prepare_rollback(error.clone()).is_prepare_rollback());
        assert!(ProcedureState::rolling_back(error.clone()).is_rolling_back());

        let poisoned = ProcedureState::poisoned(PoisonKeys::single("k"), error.clone());
        assert!(poisoned.is_poisoned());
        assert!(poisoned.error().is_some());
        assert!(ProcedureState::retrying(error.clone()).error().is_some());
        assert!(ProcedureState::rolling_back(error).error().is_some());
    }
}