1use std::any::Any;
16use std::fmt;
17use std::fmt::Display;
18use std::str::FromStr;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use common_event_recorder::{Event, Eventable};
23use serde::{Deserialize, Serialize};
24use smallvec::{SmallVec, smallvec};
25use snafu::{ResultExt, Snafu};
26use tokio::sync::watch::Receiver;
27use uuid::Uuid;
28
29use crate::error::{self, Error, Result};
30use crate::local::DynamicKeyLockGuard;
31use crate::watcher::Watcher;
32
/// Type-erased output a procedure may produce when it finishes.
pub type Output = Arc<dyn Any + Send + Sync>;

/// Status returned by [`Procedure::execute`].
#[derive(Debug)]
pub enum Status {
    /// The procedure is still executing (presumably the runner calls `execute`
    /// again — confirm with the runner).
    Executing {
        /// Whether the procedure state needs to be persisted; reported by
        /// [`Status::need_persist`].
        persist: bool,
        /// Whether poisons should be cleaned; reported by
        /// [`Status::need_clean_poisons`].
        clean_poisons: bool,
    },
    /// The procedure is suspended on the given `subprocedures`.
    Suspended {
        /// Sub-procedures the procedure submitted and (presumably) waits for.
        subprocedures: Vec<ProcedureWithId>,
        /// Whether the procedure state needs to be persisted; reported by
        /// [`Status::need_persist`].
        persist: bool,
    },
    /// The procedure poisoned `keys` because of `error`.
    Poisoned {
        /// Keys that were poisoned.
        keys: PoisonKeys,
        /// Error that caused the poisoning.
        error: Error,
    },
    /// The procedure finished, optionally with an `output`.
    Done { output: Option<Output> },
}
61
62impl Status {
63 pub fn suspended(subprocedures: Vec<ProcedureWithId>, persist: bool) -> Status {
65 Status::Suspended {
66 subprocedures,
67 persist,
68 }
69 }
70
71 pub fn poisoned(keys: impl IntoIterator<Item = PoisonKey>, error: Error) -> Status {
73 Status::Poisoned {
74 keys: PoisonKeys::new(keys),
75 error,
76 }
77 }
78
79 pub fn executing(persist: bool) -> Status {
81 Status::Executing {
82 persist,
83 clean_poisons: false,
84 }
85 }
86
87 pub fn executing_with_clean_poisons(persist: bool) -> Status {
89 Status::Executing {
90 persist,
91 clean_poisons: true,
92 }
93 }
94
95 pub fn done() -> Status {
97 Status::Done { output: None }
98 }
99
100 #[cfg(any(test, feature = "testing"))]
101 pub fn downcast_output_ref<T: 'static>(&self) -> Option<&T> {
106 if let Status::Done { output } = self {
107 output
108 .as_ref()
109 .expect("Try to downcast the output of Status::Done, but the output is None")
110 .downcast_ref()
111 } else {
112 panic!("Expected the Status::Done, but got: {:?}", self)
113 }
114 }
115
116 pub fn done_with_output<T: Any + Send + Sync>(output: T) -> Status {
118 Status::Done {
119 output: Some(Arc::new(output)),
120 }
121 }
122 pub fn is_done(&self) -> bool {
124 matches!(self, Status::Done { .. })
125 }
126
127 pub fn need_persist(&self) -> bool {
129 match self {
130 Status::Executing { persist, .. } | Status::Suspended { persist, .. } => *persist,
133 Status::Done { .. } | Status::Poisoned { .. } => false,
134 }
135 }
136
137 pub fn need_clean_poisons(&self) -> bool {
139 match self {
140 Status::Executing { clean_poisons, .. } => *clean_poisons,
141 Status::Done { .. } => true,
142 _ => false,
143 }
144 }
145}
146
/// Provides procedure-related services to a running [`Procedure`] via its [`Context`].
#[async_trait]
pub trait ContextProvider: Send + Sync {
    /// Queries the state of the procedure with `procedure_id`.
    ///
    /// `Ok(None)` presumably means the provider knows no such procedure — confirm
    /// with implementors.
    async fn procedure_state(&self, procedure_id: ProcedureId) -> Result<Option<ProcedureState>>;

    /// Returns a watch [`Receiver`] of the procedure's state, if one is available.
    async fn procedure_state_receiver(
        &self,
        procedure_id: ProcedureId,
    ) -> Result<Option<Receiver<ProcedureState>>>;

    /// Tries to put the poison `key` on behalf of `procedure_id`.
    ///
    /// NOTE(review): what happens when the key is already poisoned (e.g. by another
    /// procedure) is defined by the implementor — confirm.
    async fn try_put_poison(&self, key: &PoisonKey, procedure_id: ProcedureId) -> Result<()>;

    /// Acquires the lock for `key` and returns its guard.
    ///
    /// Presumably the lock is released when the guard is dropped — confirm.
    async fn acquire_lock(&self, key: &StringKey) -> DynamicKeyLockGuard;
}

/// Shared, reference-counted [`ContextProvider`].
pub type ContextProviderRef = Arc<dyn ContextProvider>;
170
171#[derive(Clone)]
173pub struct Context {
174 pub procedure_id: ProcedureId,
176 pub provider: ContextProviderRef,
178}
179
180impl Context {
181 pub async fn is_retrying(&self) -> Option<bool> {
183 self.provider
184 .procedure_state(self.procedure_id)
185 .await
186 .ok()
187 .flatten()
188 .map(|s| s.is_retrying())
189 }
190}
191
/// A unit of work driven by a [`ProcedureManager`].
#[async_trait]
pub trait Procedure: Send {
    /// Returns the type name of the procedure.
    fn type_name(&self) -> &str;

    /// Executes the next step of the procedure and returns its [`Status`].
    async fn execute(&mut self, ctx: &Context) -> Result<Status>;

    /// Rolls back the procedure.
    ///
    /// The default implementation fails with a `RollbackNotSupported` error.
    async fn rollback(&mut self, _: &Context) -> Result<()> {
        error::RollbackNotSupportedSnafu {}.fail()
    }

    /// Returns whether the procedure supports rollback. Defaults to `false`.
    fn rollback_supported(&self) -> bool {
        false
    }

    /// Dumps the procedure into a string.
    ///
    /// Presumably this is the form later fed to a [`BoxedProcedureLoader`] — confirm.
    fn dump(&self) -> Result<String>;

    /// Hook for restoring in-memory state (presumably after loading); defaults to a no-op.
    fn recover(&mut self) -> Result<()> {
        Ok(())
    }

    /// Returns the keys the procedure needs to lock.
    fn lock_key(&self) -> LockKey;

    /// Returns the poison keys of the procedure. Defaults to none.
    fn poison_keys(&self) -> PoisonKeys {
        PoisonKeys::default()
    }

    /// Returns optional user metadata. Defaults to `None`.
    fn user_metadata(&self) -> Option<UserMetadata> {
        None
    }
}
236
237#[derive(Clone, Debug)]
239pub struct UserMetadata {
240 event_object: Arc<dyn Eventable>,
241}
242
243impl UserMetadata {
244 pub fn new(event_object: Arc<dyn Eventable>) -> Self {
246 Self { event_object }
247 }
248
249 pub fn to_event(&self) -> Option<Box<dyn Event>> {
251 self.event_object.to_event()
252 }
253}
254
255#[async_trait]
256impl<T: Procedure + ?Sized> Procedure for Box<T> {
257 fn type_name(&self) -> &str {
258 (**self).type_name()
259 }
260
261 async fn execute(&mut self, ctx: &Context) -> Result<Status> {
262 (**self).execute(ctx).await
263 }
264
265 async fn rollback(&mut self, ctx: &Context) -> Result<()> {
266 (**self).rollback(ctx).await
267 }
268
269 fn rollback_supported(&self) -> bool {
270 (**self).rollback_supported()
271 }
272
273 fn dump(&self) -> Result<String> {
274 (**self).dump()
275 }
276
277 fn lock_key(&self) -> LockKey {
278 (**self).lock_key()
279 }
280
281 fn poison_keys(&self) -> PoisonKeys {
282 (**self).poison_keys()
283 }
284}
285
/// A single poison key, stored as a string.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PoisonKey(String);

impl Display for PoisonKey {
    /// Writes the raw key; formatter flags such as width are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl PoisonKey {
    /// Builds a poison key from anything convertible into a `String`.
    pub fn new(key: impl Into<String>) -> Self {
        let owned: String = key.into();
        PoisonKey(owned)
    }
}
301
302#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
306pub struct PoisonKeys(SmallVec<[PoisonKey; 2]>);
307
308impl PoisonKeys {
309 pub fn single(key: impl Into<String>) -> Self {
311 Self(smallvec![PoisonKey::new(key)])
312 }
313
314 pub fn new(keys: impl IntoIterator<Item = PoisonKey>) -> Self {
316 Self(keys.into_iter().collect())
317 }
318
319 pub fn contains(&self, key: &PoisonKey) -> bool {
321 self.0.contains(key)
322 }
323
324 pub fn iter(&self) -> impl Iterator<Item = &PoisonKey> {
326 self.0.iter()
327 }
328}
329
/// A lock key together with its lock mode.
///
/// The derived `Ord` compares the variant first, so `Share` keys sort before
/// `Exclusive` keys and ties are broken by the string.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StringKey {
    /// Key locked in shared mode.
    Share(String),
    /// Key locked in exclusive mode.
    Exclusive(String),
}
335
/// Keys a procedure needs to lock.
///
/// Keys built through `LockKey::new` are sorted and deduplicated. Presumably this
/// is so that locks are always acquired in a consistent order — confirm.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct LockKey(SmallVec<[StringKey; 2]>);
343
344impl StringKey {
345 pub fn into_string(self) -> String {
346 match self {
347 StringKey::Share(s) => s,
348 StringKey::Exclusive(s) => s,
349 }
350 }
351
352 pub fn as_string(&self) -> &String {
353 match self {
354 StringKey::Share(s) => s,
355 StringKey::Exclusive(s) => s,
356 }
357 }
358}
359
360impl LockKey {
361 pub fn single(key: impl Into<StringKey>) -> LockKey {
363 LockKey(smallvec![key.into()])
364 }
365
366 pub fn single_exclusive(key: impl Into<String>) -> LockKey {
368 LockKey(smallvec![StringKey::Exclusive(key.into())])
369 }
370
371 pub fn new(iter: impl IntoIterator<Item = StringKey>) -> LockKey {
373 let mut vec: SmallVec<_> = iter.into_iter().collect();
374 vec.sort();
375 vec.dedup();
377 LockKey(vec)
378 }
379
380 pub fn new_exclusive(iter: impl IntoIterator<Item = String>) -> LockKey {
382 Self::new(iter.into_iter().map(StringKey::Exclusive))
383 }
384
385 pub fn keys_to_lock(&self) -> impl Iterator<Item = &StringKey> {
387 self.0.iter()
388 }
389
390 pub fn get_keys(&self) -> Vec<String> {
392 self.0.iter().map(|key| format!("{:?}", key)).collect()
393 }
394}
395
396pub type BoxedProcedure = Box<dyn Procedure>;
398
399pub struct ProcedureWithId {
401 pub id: ProcedureId,
403 pub procedure: BoxedProcedure,
404}
405
406impl ProcedureWithId {
407 pub fn with_random_id(procedure: BoxedProcedure) -> ProcedureWithId {
410 ProcedureWithId {
411 id: ProcedureId::random(),
412 procedure,
413 }
414 }
415}
416
417impl fmt::Debug for ProcedureWithId {
418 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
419 write!(f, "{}-{}", self.procedure.type_name(), self.id)
420 }
421}
422
423#[derive(Debug, Snafu)]
424pub struct ParseIdError {
425 source: uuid::Error,
426}
427
428#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
430pub struct ProcedureId(Uuid);
431
432impl ProcedureId {
433 pub fn random() -> ProcedureId {
435 ProcedureId(Uuid::new_v4())
436 }
437
438 pub fn parse_str(input: &str) -> std::result::Result<ProcedureId, ParseIdError> {
440 Uuid::parse_str(input)
441 .map(ProcedureId)
442 .context(ParseIdSnafu)
443 }
444}
445
446impl fmt::Display for ProcedureId {
447 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
448 write!(f, "{}", self.0)
449 }
450}
451
452impl FromStr for ProcedureId {
453 type Err = ParseIdError;
454
455 fn from_str(s: &str) -> std::result::Result<ProcedureId, ParseIdError> {
456 ProcedureId::parse_str(s)
457 }
458}
459
/// Builds a [`BoxedProcedure`] from a string.
///
/// Presumably the string is the output of [`Procedure::dump`] — confirm.
pub type BoxedProcedureLoader = Box<dyn Fn(&str) -> Result<BoxedProcedure> + Send>;

/// Observable state of a procedure.
#[derive(Debug, Default, Clone)]
pub enum ProcedureState {
    /// The procedure is running. This is the default state.
    #[default]
    Running,
    /// The procedure finished, optionally with an `output`.
    Done { output: Option<Output> },
    /// The procedure is retrying after `error`.
    Retrying { error: Arc<Error> },
    /// The procedure is preparing to roll back because of `error`.
    PrepareRollback { error: Arc<Error> },
    /// The procedure is rolling back because of `error`.
    RollingBack { error: Arc<Error> },
    /// The procedure failed with `error`.
    Failed { error: Arc<Error> },
    /// The procedure poisoned `keys` because of `error`.
    Poisoned { keys: PoisonKeys, error: Arc<Error> },
}
482
483impl ProcedureState {
484 pub fn failed(error: Arc<Error>) -> ProcedureState {
486 ProcedureState::Failed { error }
487 }
488
489 pub fn prepare_rollback(error: Arc<Error>) -> ProcedureState {
491 ProcedureState::PrepareRollback { error }
492 }
493
494 pub fn rolling_back(error: Arc<Error>) -> ProcedureState {
496 ProcedureState::RollingBack { error }
497 }
498
499 pub fn retrying(error: Arc<Error>) -> ProcedureState {
501 ProcedureState::Retrying { error }
502 }
503
504 pub fn poisoned(keys: PoisonKeys, error: Arc<Error>) -> ProcedureState {
506 ProcedureState::Poisoned { keys, error }
507 }
508
509 pub fn is_running(&self) -> bool {
511 matches!(self, ProcedureState::Running)
512 }
513
514 pub fn is_done(&self) -> bool {
516 matches!(self, ProcedureState::Done { .. })
517 }
518
519 pub fn is_poisoned(&self) -> bool {
521 matches!(self, ProcedureState::Poisoned { .. })
522 }
523
524 pub fn is_failed(&self) -> bool {
526 matches!(self, ProcedureState::Failed { .. })
527 }
528
529 pub fn is_retrying(&self) -> bool {
531 matches!(self, ProcedureState::Retrying { .. })
532 }
533
534 pub fn is_rolling_back(&self) -> bool {
536 matches!(self, ProcedureState::RollingBack { .. })
537 }
538
539 pub fn is_prepare_rollback(&self) -> bool {
541 matches!(self, ProcedureState::PrepareRollback { .. })
542 }
543
544 pub fn error(&self) -> Option<&Arc<Error>> {
546 match self {
547 ProcedureState::Failed { error } => Some(error),
548 ProcedureState::Retrying { error } => Some(error),
549 ProcedureState::RollingBack { error } => Some(error),
550 ProcedureState::Poisoned { error, .. } => Some(error),
551 _ => None,
552 }
553 }
554
555 pub fn as_str_name(&self) -> &str {
557 match self {
558 ProcedureState::Running => "Running",
559 ProcedureState::Done { .. } => "Done",
560 ProcedureState::Retrying { .. } => "Retrying",
561 ProcedureState::Failed { .. } => "Failed",
562 ProcedureState::PrepareRollback { .. } => "PrepareRollback",
563 ProcedureState::RollingBack { .. } => "RollingBack",
564 ProcedureState::Poisoned { .. } => "Poisoned",
565 }
566 }
567}
568
/// State a procedure is (re)started in.
///
/// NOTE(review): presumably used when loading procedures, to decide whether to
/// resume execution or rollback — confirm with callers.
#[derive(Debug, Clone)]
pub enum InitProcedureState {
    /// Start by running the procedure.
    Running,
    /// Start by rolling the procedure back.
    RollingBack,
}
575
/// Manages the submission and execution of [`Procedure`]s.
#[async_trait]
pub trait ProcedureManager: Send + Sync + 'static {
    /// Registers `loader` under `name`.
    ///
    /// `name` is presumably a procedure [`Procedure::type_name`] — confirm.
    fn register_loader(&self, name: &str, loader: BoxedProcedureLoader) -> Result<()>;

    /// Starts the manager.
    async fn start(&self) -> Result<()>;

    /// Stops the manager.
    async fn stop(&self) -> Result<()>;

    /// Submits `procedure` and returns a [`Watcher`] for it.
    async fn submit(&self, procedure: ProcedureWithId) -> Result<Watcher>;

    /// Queries the state of the procedure with `procedure_id`.
    ///
    /// `Ok(None)` presumably means the id is unknown — confirm.
    async fn procedure_state(&self, procedure_id: ProcedureId) -> Result<Option<ProcedureState>>;

    /// Returns a [`Watcher`] for the procedure, if one is available.
    fn procedure_watcher(&self, procedure_id: ProcedureId) -> Option<Watcher>;

    /// Lists information about procedures known to the manager.
    async fn list_procedures(&self) -> Result<Vec<ProcedureInfo>>;
}

/// Shared, reference-counted [`ProcedureManager`].
pub type ProcedureManagerRef = Arc<dyn ProcedureManager>;
612
/// Summary information about a procedure, as returned by
/// [`ProcedureManager::list_procedures`].
#[derive(Debug, Clone)]
pub struct ProcedureInfo {
    /// Id of the procedure.
    pub id: ProcedureId,
    /// Type name of the procedure.
    pub type_name: String,
    /// Start time; presumably milliseconds since the Unix epoch — confirm.
    pub start_time_ms: i64,
    /// End time; presumably milliseconds since the Unix epoch — confirm.
    pub end_time_ms: i64,
    /// Current state of the procedure.
    pub state: ProcedureState,
    /// Lock keys as strings; presumably produced by `LockKey::get_keys` — confirm.
    pub lock_keys: Vec<String>,
}
628
#[cfg(test)]
mod tests {
    use common_error::mock::MockError;
    use common_error::status_code::StatusCode;

    use super::*;

    /// Builds an opaque error for statuses/states that must carry one.
    fn mock_error() -> Error {
        Error::external(MockError::new(StatusCode::Unexpected))
    }

    #[test]
    fn test_status() {
        let status = Status::executing(false);
        assert!(!status.need_persist());
        assert!(!status.need_clean_poisons());

        let status = Status::executing(true);
        assert!(status.need_persist());
        assert!(!status.need_clean_poisons());

        let status = Status::executing_with_clean_poisons(false);
        assert!(!status.need_persist());
        assert!(status.need_clean_poisons());

        let status = Status::executing_with_clean_poisons(true);
        assert!(status.need_persist());
        assert!(status.need_clean_poisons());

        let status = Status::Suspended {
            subprocedures: Vec::new(),
            persist: false,
        };
        assert!(!status.need_persist());
        assert!(!status.need_clean_poisons());

        let status = Status::Suspended {
            subprocedures: Vec::new(),
            persist: true,
        };
        assert!(status.need_persist());

        let status = Status::suspended(Vec::new(), true);
        assert!(status.need_persist());
        assert!(!status.is_done());

        let status = Status::done();
        assert!(status.is_done());
        assert!(!status.need_persist());
        assert!(status.need_clean_poisons());
    }

    #[test]
    fn test_poisoned_status() {
        let status = Status::poisoned([PoisonKey::new("table/1")], mock_error());
        assert!(!status.need_persist());
        assert!(!status.need_clean_poisons());
        assert!(!status.is_done());
        match status {
            Status::Poisoned { keys, .. } => {
                assert!(keys.contains(&PoisonKey::new("table/1")));
                assert!(!keys.contains(&PoisonKey::new("table/2")));
            }
            other => panic!("expected Status::Poisoned, got {other:?}"),
        }
    }

    #[test]
    fn test_done_with_output() {
        let status = Status::done_with_output(42u32);
        assert!(status.is_done());
        assert_eq!(Some(&42u32), status.downcast_output_ref::<u32>());
        // A mismatched type downcasts to `None` instead of panicking.
        assert!(status.downcast_output_ref::<String>().is_none());
    }

    #[test]
    fn test_lock_key() {
        let entity = "catalog.schema.my_table";
        let key = LockKey::single_exclusive(entity);
        assert_eq!(
            vec![&StringKey::Exclusive(entity.to_string())],
            key.keys_to_lock().collect::<Vec<_>>()
        );

        let key = LockKey::new_exclusive([
            "b".to_string(),
            "c".to_string(),
            "a".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(
            vec![
                &StringKey::Exclusive("a".to_string()),
                &StringKey::Exclusive("b".to_string()),
                &StringKey::Exclusive("c".to_string())
            ],
            key.keys_to_lock().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_lock_key_mixed_modes() {
        let key = LockKey::new([
            StringKey::Exclusive("a".to_string()),
            StringKey::Share("b".to_string()),
            StringKey::Share("a".to_string()),
            StringKey::Share("b".to_string()),
        ]);
        // Derived `Ord` orders by variant first (`Share` < `Exclusive`), then by string.
        assert_eq!(
            vec![
                &StringKey::Share("a".to_string()),
                &StringKey::Share("b".to_string()),
                &StringKey::Exclusive("a".to_string()),
            ],
            key.keys_to_lock().collect::<Vec<_>>()
        );

        assert_eq!(
            vec!["Exclusive(\"t\")".to_string()],
            LockKey::single_exclusive("t").get_keys()
        );
        assert_eq!(0, LockKey::default().keys_to_lock().count());
    }

    #[test]
    fn test_string_key() {
        let key = StringKey::Share("s".to_string());
        assert_eq!("s", key.as_string());
        assert_eq!("s".to_string(), key.into_string());

        let key = StringKey::Exclusive("e".to_string());
        assert_eq!("e", key.as_string());
        assert_eq!("e".to_string(), key.into_string());
    }

    #[test]
    fn test_poison_keys() {
        let key = PoisonKey::new("table/1");
        assert_eq!("table/1", key.to_string());

        let keys = PoisonKeys::single("a");
        assert!(keys.contains(&PoisonKey::new("a")));
        assert!(!keys.contains(&PoisonKey::new("b")));

        let keys = PoisonKeys::new([PoisonKey::new("x"), PoisonKey::new("y")]);
        assert_eq!(
            vec!["x".to_string(), "y".to_string()],
            keys.iter().map(|k| k.to_string()).collect::<Vec<_>>()
        );
        assert_eq!(0, PoisonKeys::default().iter().count());
    }

    #[test]
    fn test_procedure_id() {
        let id = ProcedureId::random();
        let uuid_str = id.to_string();
        assert_eq!(id.0.to_string(), uuid_str);

        let parsed = ProcedureId::parse_str(&uuid_str).unwrap();
        assert_eq!(id, parsed);
        let parsed = uuid_str.parse().unwrap();
        assert_eq!(id, parsed);

        assert!(ProcedureId::parse_str("not-a-uuid").is_err());
    }

    #[test]
    fn test_procedure_id_serialization() {
        let id = ProcedureId::random();
        let json = serde_json::to_string(&id).unwrap();
        assert_eq!(format!("\"{id}\""), json);

        let parsed = serde_json::from_str(&json).unwrap();
        assert_eq!(id, parsed);
    }

    #[test]
    fn test_procedure_state() {
        assert!(ProcedureState::Running.is_running());
        assert!(ProcedureState::Running.error().is_none());
        assert!(ProcedureState::default().is_running());
        assert_eq!("Running", ProcedureState::default().as_str_name());
        assert!(ProcedureState::Done { output: None }.is_done());
        assert!(ProcedureState::Done { output: None }.error().is_none());
        assert_eq!("Done", ProcedureState::Done { output: None }.as_str_name());

        let state = ProcedureState::failed(Arc::new(mock_error()));
        assert!(state.is_failed());
        let _ = state.error().unwrap();
        assert_eq!("Failed", state.as_str_name());

        let error = Arc::new(mock_error());

        let state = ProcedureState::retrying(Arc::clone(&error));
        assert!(state.is_retrying());
        assert!(state.error().is_some());
        assert_eq!("Retrying", state.as_str_name());

        let state = ProcedureState::rolling_back(Arc::clone(&error));
        assert!(state.is_rolling_back());
        assert!(state.error().is_some());
        assert_eq!("RollingBack", state.as_str_name());

        let state = ProcedureState::prepare_rollback(Arc::clone(&error));
        assert!(state.is_prepare_rollback());
        assert_eq!("PrepareRollback", state.as_str_name());

        let state = ProcedureState::poisoned(PoisonKeys::single("k"), error);
        assert!(state.is_poisoned());
        assert!(state.error().is_some());
        assert_eq!("Poisoned", state.as_str_name());
    }
}