1use std::any::Any;
16use std::fmt;
17use std::fmt::Display;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use serde::{Deserialize, Serialize};
23use smallvec::{smallvec, SmallVec};
24use snafu::{ResultExt, Snafu};
25use uuid::Uuid;
26
27use crate::error::{self, Error, Result};
28use crate::local::DynamicKeyLockGuard;
29use crate::watcher::Watcher;
30
/// Type-erased, cheaply clonable value produced by a finished procedure.
///
/// Readers recover the concrete type with `downcast_ref`.
pub type Output = Arc<dyn Any + Sync + Send>;
32
33#[derive(Debug)]
35pub enum Status {
36 Executing {
38 persist: bool,
40 clean_poisons: bool,
42 },
43 Suspended {
45 subprocedures: Vec<ProcedureWithId>,
46 persist: bool,
48 },
49 Poisoned {
51 keys: PoisonKeys,
53 error: Error,
55 },
56 Done { output: Option<Output> },
58}
59
60impl Status {
61 pub fn poisoned(keys: impl IntoIterator<Item = PoisonKey>, error: Error) -> Status {
63 Status::Poisoned {
64 keys: PoisonKeys::new(keys),
65 error,
66 }
67 }
68
69 pub fn executing(persist: bool) -> Status {
71 Status::Executing {
72 persist,
73 clean_poisons: false,
74 }
75 }
76
77 pub fn executing_with_clean_poisons(persist: bool) -> Status {
79 Status::Executing {
80 persist,
81 clean_poisons: true,
82 }
83 }
84
85 pub fn done() -> Status {
87 Status::Done { output: None }
88 }
89
90 #[cfg(any(test, feature = "testing"))]
91 pub fn downcast_output_ref<T: 'static>(&self) -> Option<&T> {
96 if let Status::Done { output } = self {
97 output
98 .as_ref()
99 .expect("Try to downcast the output of Status::Done, but the output is None")
100 .downcast_ref()
101 } else {
102 panic!("Expected the Status::Done, but got: {:?}", self)
103 }
104 }
105
106 pub fn done_with_output<T: Any + Send + Sync>(output: T) -> Status {
108 Status::Done {
109 output: Some(Arc::new(output)),
110 }
111 }
112 pub fn is_done(&self) -> bool {
114 matches!(self, Status::Done { .. })
115 }
116
117 pub fn need_persist(&self) -> bool {
119 match self {
120 Status::Executing { persist, .. } | Status::Suspended { persist, .. } => *persist,
123 Status::Done { .. } | Status::Poisoned { .. } => false,
124 }
125 }
126
127 pub fn need_clean_poisons(&self) -> bool {
129 match self {
130 Status::Executing { clean_poisons, .. } => *clean_poisons,
131 Status::Done { .. } => true,
132 _ => false,
133 }
134 }
135}
136
137#[async_trait]
139pub trait ContextProvider: Send + Sync {
140 async fn procedure_state(&self, procedure_id: ProcedureId) -> Result<Option<ProcedureState>>;
142
143 async fn try_put_poison(&self, key: &PoisonKey, procedure_id: ProcedureId) -> Result<()>;
148
149 async fn acquire_lock(&self, key: &StringKey) -> DynamicKeyLockGuard;
151}
152
153pub type ContextProviderRef = Arc<dyn ContextProvider>;
155
156#[derive(Clone)]
158pub struct Context {
159 pub procedure_id: ProcedureId,
161 pub provider: ContextProviderRef,
163}
164
165#[async_trait]
167pub trait Procedure: Send {
168 fn type_name(&self) -> &str;
170
171 async fn execute(&mut self, ctx: &Context) -> Result<Status>;
175
176 async fn rollback(&mut self, _: &Context) -> Result<()> {
180 error::RollbackNotSupportedSnafu {}.fail()
181 }
182
183 fn rollback_supported(&self) -> bool {
185 false
186 }
187
188 fn dump(&self) -> Result<String>;
190
191 fn recover(&mut self) -> Result<()> {
193 Ok(())
194 }
195
196 fn lock_key(&self) -> LockKey;
198
199 fn poison_keys(&self) -> PoisonKeys {
201 PoisonKeys::default()
202 }
203}
204
205#[async_trait]
206impl<T: Procedure + ?Sized> Procedure for Box<T> {
207 fn type_name(&self) -> &str {
208 (**self).type_name()
209 }
210
211 async fn execute(&mut self, ctx: &Context) -> Result<Status> {
212 (**self).execute(ctx).await
213 }
214
215 async fn rollback(&mut self, ctx: &Context) -> Result<()> {
216 (**self).rollback(ctx).await
217 }
218
219 fn rollback_supported(&self) -> bool {
220 (**self).rollback_supported()
221 }
222
223 fn dump(&self) -> Result<String> {
224 (**self).dump()
225 }
226
227 fn lock_key(&self) -> LockKey {
228 (**self).lock_key()
229 }
230
231 fn poison_keys(&self) -> PoisonKeys {
232 (**self).poison_keys()
233 }
234}
235
/// Name of a poisoned key (see [`Status::Poisoned`]), stored as an owned string.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PoisonKey(String);

impl Display for PoisonKey {
    /// Writes the raw key string; width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl PoisonKey {
    /// Wraps anything convertible into a `String` as a poison key.
    pub fn new(key: impl Into<String>) -> Self {
        let owned: String = key.into();
        PoisonKey(owned)
    }
}
251
252#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
256pub struct PoisonKeys(SmallVec<[PoisonKey; 2]>);
257
258impl PoisonKeys {
259 pub fn single(key: impl Into<String>) -> Self {
261 Self(smallvec![PoisonKey::new(key)])
262 }
263
264 pub fn new(keys: impl IntoIterator<Item = PoisonKey>) -> Self {
266 Self(keys.into_iter().collect())
267 }
268
269 pub fn contains(&self, key: &PoisonKey) -> bool {
271 self.0.contains(key)
272 }
273
274 pub fn iter(&self) -> impl Iterator<Item = &PoisonKey> {
276 self.0.iter()
277 }
278}
279
/// A lock key paired with the mode it should be locked in.
///
/// Variant order is significant: the derived `Ord` places every `Share` key before any
/// `Exclusive` key, which determines the order produced by `LockKey::new`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StringKey {
    /// Lock `key` in share mode.
    Share(String),
    /// Lock `key` in exclusive mode.
    Exclusive(String),
}
285
286#[derive(Clone, Debug, Default, PartialEq, Eq)]
292pub struct LockKey(SmallVec<[StringKey; 2]>);
293
294impl StringKey {
295 pub fn into_string(self) -> String {
296 match self {
297 StringKey::Share(s) => s,
298 StringKey::Exclusive(s) => s,
299 }
300 }
301
302 pub fn as_string(&self) -> &String {
303 match self {
304 StringKey::Share(s) => s,
305 StringKey::Exclusive(s) => s,
306 }
307 }
308}
309
310impl LockKey {
311 pub fn single(key: impl Into<StringKey>) -> LockKey {
313 LockKey(smallvec![key.into()])
314 }
315
316 pub fn single_exclusive(key: impl Into<String>) -> LockKey {
318 LockKey(smallvec![StringKey::Exclusive(key.into())])
319 }
320
321 pub fn new(iter: impl IntoIterator<Item = StringKey>) -> LockKey {
323 let mut vec: SmallVec<_> = iter.into_iter().collect();
324 vec.sort();
325 vec.dedup();
327 LockKey(vec)
328 }
329
330 pub fn new_exclusive(iter: impl IntoIterator<Item = String>) -> LockKey {
332 Self::new(iter.into_iter().map(StringKey::Exclusive))
333 }
334
335 pub fn keys_to_lock(&self) -> impl Iterator<Item = &StringKey> {
337 self.0.iter()
338 }
339
340 pub fn get_keys(&self) -> Vec<String> {
342 self.0.iter().map(|key| format!("{:?}", key)).collect()
343 }
344}
345
346pub type BoxedProcedure = Box<dyn Procedure>;
348
349pub struct ProcedureWithId {
351 pub id: ProcedureId,
353 pub procedure: BoxedProcedure,
354}
355
356impl ProcedureWithId {
357 pub fn with_random_id(procedure: BoxedProcedure) -> ProcedureWithId {
360 ProcedureWithId {
361 id: ProcedureId::random(),
362 procedure,
363 }
364 }
365}
366
367impl fmt::Debug for ProcedureWithId {
368 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
369 write!(f, "{}-{}", self.procedure.type_name(), self.id)
370 }
371}
372
373#[derive(Debug, Snafu)]
374pub struct ParseIdError {
375 source: uuid::Error,
376}
377
378#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
380pub struct ProcedureId(Uuid);
381
382impl ProcedureId {
383 pub fn random() -> ProcedureId {
385 ProcedureId(Uuid::new_v4())
386 }
387
388 pub fn parse_str(input: &str) -> std::result::Result<ProcedureId, ParseIdError> {
390 Uuid::parse_str(input)
391 .map(ProcedureId)
392 .context(ParseIdSnafu)
393 }
394}
395
396impl fmt::Display for ProcedureId {
397 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
398 write!(f, "{}", self.0)
399 }
400}
401
402impl FromStr for ProcedureId {
403 type Err = ParseIdError;
404
405 fn from_str(s: &str) -> std::result::Result<ProcedureId, ParseIdError> {
406 ProcedureId::parse_str(s)
407 }
408}
409
410pub type BoxedProcedureLoader = Box<dyn Fn(&str) -> Result<BoxedProcedure> + Send>;
412
413#[derive(Debug, Default, Clone)]
415pub enum ProcedureState {
416 #[default]
418 Running,
419 Done { output: Option<Output> },
421 Retrying { error: Arc<Error> },
423 PrepareRollback { error: Arc<Error> },
425 RollingBack { error: Arc<Error> },
427 Failed { error: Arc<Error> },
429 Poisoned { keys: PoisonKeys, error: Arc<Error> },
431}
432
433impl ProcedureState {
434 pub fn failed(error: Arc<Error>) -> ProcedureState {
436 ProcedureState::Failed { error }
437 }
438
439 pub fn prepare_rollback(error: Arc<Error>) -> ProcedureState {
441 ProcedureState::PrepareRollback { error }
442 }
443
444 pub fn rolling_back(error: Arc<Error>) -> ProcedureState {
446 ProcedureState::RollingBack { error }
447 }
448
449 pub fn retrying(error: Arc<Error>) -> ProcedureState {
451 ProcedureState::Retrying { error }
452 }
453
454 pub fn poisoned(keys: PoisonKeys, error: Arc<Error>) -> ProcedureState {
456 ProcedureState::Poisoned { keys, error }
457 }
458
459 pub fn is_running(&self) -> bool {
461 matches!(self, ProcedureState::Running)
462 }
463
464 pub fn is_done(&self) -> bool {
466 matches!(self, ProcedureState::Done { .. })
467 }
468
469 pub fn is_poisoned(&self) -> bool {
471 matches!(self, ProcedureState::Poisoned { .. })
472 }
473
474 pub fn is_failed(&self) -> bool {
476 matches!(self, ProcedureState::Failed { .. })
477 }
478
479 pub fn is_retrying(&self) -> bool {
481 matches!(self, ProcedureState::Retrying { .. })
482 }
483
484 pub fn is_rolling_back(&self) -> bool {
486 matches!(self, ProcedureState::RollingBack { .. })
487 }
488
489 pub fn is_prepare_rollback(&self) -> bool {
491 matches!(self, ProcedureState::PrepareRollback { .. })
492 }
493
494 pub fn error(&self) -> Option<&Arc<Error>> {
496 match self {
497 ProcedureState::Failed { error } => Some(error),
498 ProcedureState::Retrying { error } => Some(error),
499 ProcedureState::RollingBack { error } => Some(error),
500 ProcedureState::Poisoned { error, .. } => Some(error),
501 _ => None,
502 }
503 }
504
505 pub fn as_str_name(&self) -> &str {
507 match self {
508 ProcedureState::Running => "Running",
509 ProcedureState::Done { .. } => "Done",
510 ProcedureState::Retrying { .. } => "Retrying",
511 ProcedureState::Failed { .. } => "Failed",
512 ProcedureState::PrepareRollback { .. } => "PrepareRollback",
513 ProcedureState::RollingBack { .. } => "RollingBack",
514 ProcedureState::Poisoned { .. } => "Poisoned",
515 }
516 }
517}
518
/// State a procedure starts from.
///
/// NOTE(review): presumably chosen when a procedure is (re)started by the manager — confirm.
#[derive(Clone, Debug)]
pub enum InitProcedureState {
    /// Start in the running state.
    Running,
    /// Start in the rolling-back state.
    RollingBack,
}
525
526#[async_trait]
529pub trait ProcedureManager: Send + Sync + 'static {
530 fn register_loader(&self, name: &str, loader: BoxedProcedureLoader) -> Result<()>;
532
533 async fn start(&self) -> Result<()>;
539
540 async fn stop(&self) -> Result<()>;
542
543 async fn submit(&self, procedure: ProcedureWithId) -> Result<Watcher>;
547
548 async fn procedure_state(&self, procedure_id: ProcedureId) -> Result<Option<ProcedureState>>;
552
553 fn procedure_watcher(&self, procedure_id: ProcedureId) -> Option<Watcher>;
555
556 async fn list_procedures(&self) -> Result<Vec<ProcedureInfo>>;
558}
559
560pub type ProcedureManagerRef = Arc<dyn ProcedureManager>;
562
563#[derive(Debug, Clone)]
564pub struct ProcedureInfo {
565 pub id: ProcedureId,
567 pub type_name: String,
569 pub start_time_ms: i64,
571 pub end_time_ms: i64,
573 pub state: ProcedureState,
575 pub lock_keys: Vec<String>,
577}
578
#[cfg(test)]
mod tests {
    use common_error::mock::MockError;
    use common_error::status_code::StatusCode;

    use super::*;

    #[test]
    fn test_status() {
        // `Executing` reports its `persist` flag unchanged.
        for persist in [false, true] {
            assert_eq!(persist, Status::executing(persist).need_persist());
        }

        // Clean-poisons statuses always ask for cleaning, whatever `persist` is.
        for persist in [false, true] {
            assert!(Status::executing_with_clean_poisons(persist).need_clean_poisons());
        }

        // `Suspended` reports its `persist` flag unchanged.
        for persist in [false, true] {
            let suspended = Status::Suspended {
                subprocedures: Vec::new(),
                persist,
            };
            assert_eq!(persist, suspended.need_persist());
        }

        let done = Status::done();
        assert!(!done.need_persist());
        assert!(done.need_clean_poisons());
    }

    #[test]
    fn test_lock_key() {
        let entity = "catalog.schema.my_table";
        let single: Vec<StringKey> = LockKey::single_exclusive(entity)
            .keys_to_lock()
            .cloned()
            .collect();
        assert_eq!(vec![StringKey::Exclusive(entity.to_string())], single);

        // Keys come back sorted, with duplicates removed.
        let key = LockKey::new_exclusive(["b", "c", "a", "c"].iter().map(|s| s.to_string()));
        let expected: Vec<StringKey> = ["a", "b", "c"]
            .iter()
            .map(|s| StringKey::Exclusive(s.to_string()))
            .collect();
        let actual: Vec<StringKey> = key.keys_to_lock().cloned().collect();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_procedure_id() {
        let id = ProcedureId::random();
        let text = id.to_string();
        assert_eq!(id.0.to_string(), text);

        assert_eq!(id, ProcedureId::parse_str(&text).unwrap());
        assert_eq!(id, text.parse::<ProcedureId>().unwrap());
    }

    #[test]
    fn test_procedure_id_serialization() {
        let id = ProcedureId::random();
        let json = serde_json::to_string(&id).unwrap();
        assert_eq!(format!("\"{}\"", id), json);

        let decoded: ProcedureId = serde_json::from_str(&json).unwrap();
        assert_eq!(id, decoded);
    }

    #[test]
    fn test_procedure_state() {
        let running = ProcedureState::Running;
        assert!(running.is_running());
        assert!(running.error().is_none());
        assert!(ProcedureState::Done { output: None }.is_done());

        let error = Arc::new(Error::external(MockError::new(StatusCode::Unexpected)));
        let failed = ProcedureState::failed(error);
        assert!(failed.is_failed());
        assert!(failed.error().is_some());
    }
}