use std::any::Any;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use smallvec::{smallvec, SmallVec};
use snafu::{ResultExt, Snafu};
use uuid::Uuid;
use crate::error::{self, Error, Result};
use crate::watcher::Watcher;
/// Type-erased output value of a procedure: a shared, thread-safe `dyn Any`.
///
/// Stored as the optional `output` of [`Status::Done`] and [`ProcedureState::Done`].
pub type Output = Arc<dyn Any + Send + Sync>;
/// Status returned by [`Procedure::execute`].
#[derive(Debug)]
pub enum Status {
/// The procedure is still executing.
Executing {
/// Value reported by [`Status::need_persist`] for this status.
persist: bool,
},
/// The procedure is suspended and carries a list of sub-procedures.
///
/// NOTE(review): presumably the sub-procedures are submitted and awaited before the
/// parent resumes — confirm against the runner, which is not part of this file.
Suspended {
/// Sub-procedures, each paired with its id.
subprocedures: Vec<ProcedureWithId>,
/// Value reported by [`Status::need_persist`] for this status.
persist: bool,
},
/// The procedure finished, optionally producing an [`Output`].
Done { output: Option<Output> },
}
impl Status {
pub fn executing(persist: bool) -> Status {
Status::Executing { persist }
}
pub fn done() -> Status {
Status::Done { output: None }
}
#[cfg(any(test, feature = "testing"))]
pub fn downcast_output_ref<T: 'static>(&self) -> Option<&T> {
if let Status::Done { output } = self {
output
.as_ref()
.expect("Try to downcast the output of Status::Done, but the output is None")
.downcast_ref()
} else {
panic!("Expected the Status::Done, but got: {:?}", self)
}
}
pub fn done_with_output<T: Any + Send + Sync>(output: T) -> Status {
Status::Done {
output: Some(Arc::new(output)),
}
}
pub fn is_done(&self) -> bool {
matches!(self, Status::Done { .. })
}
pub fn need_persist(&self) -> bool {
match self {
Status::Executing { persist } | Status::Suspended { persist, .. } => *persist,
Status::Done { .. } => false,
}
}
}
/// Source of procedure-related information made available through a [`Context`].
#[async_trait]
pub trait ContextProvider: Send + Sync {
/// Asynchronously looks up the [`ProcedureState`] for `procedure_id`.
///
/// NOTE(review): `Ok(None)` presumably means the id is unknown to the provider —
/// confirm with the implementations.
async fn procedure_state(&self, procedure_id: ProcedureId) -> Result<Option<ProcedureState>>;
}
/// Shared, reference-counted [`ContextProvider`] trait object.
pub type ContextProviderRef = Arc<dyn ContextProvider>;
/// Context passed to [`Procedure::execute`] and [`Procedure::rollback`].
#[derive(Clone)]
pub struct Context {
/// Id of a procedure.
/// NOTE(review): presumably the id of the procedure being run with this context — confirm.
pub procedure_id: ProcedureId,
/// Provider that can be queried for procedure state.
pub provider: ContextProviderRef,
}
/// A unit of work that can be executed step by step, dumped to a string and locked by key.
#[async_trait]
pub trait Procedure: Send {
/// Returns the name of this procedure's type.
///
/// The name is used as the prefix of the `Debug` output of [`ProcedureWithId`].
fn type_name(&self) -> &str;
/// Runs the procedure with the given context and reports the resulting [`Status`].
async fn execute(&mut self, ctx: &Context) -> Result<Status>;
/// Rolls the procedure back.
///
/// The default implementation always fails with the `RollbackNotSupported` error.
async fn rollback(&mut self, _: &Context) -> Result<()> {
error::RollbackNotSupportedSnafu {}.fail()
}
/// Reports whether [`Procedure::rollback`] is supported; `false` by default.
fn rollback_supported(&self) -> bool {
false
}
/// Dumps the procedure into a string.
///
/// NOTE(review): presumably this is the text later handed to a
/// [`BoxedProcedureLoader`] to rebuild the procedure — confirm.
fn dump(&self) -> Result<String>;
/// Hook for recovering the procedure; the default implementation does nothing.
///
/// NOTE(review): when the framework invokes this hook is not visible in this file.
fn recover(&mut self) -> Result<()> {
Ok(())
}
/// Returns the [`LockKey`] of this procedure.
fn lock_key(&self) -> LockKey;
}
/// Forwards every [`Procedure`] method of a boxed procedure to the boxed value.
///
/// Methods that have a default body in the trait (`rollback`, `rollback_supported`,
/// `recover`) must be forwarded explicitly as well: a method call on a
/// `Box<dyn Procedure>` resolves to this impl before it auto-derefs to the inner
/// value, so an omitted method would silently run the trait default instead of the
/// inner procedure's override.
#[async_trait]
impl<T: Procedure + ?Sized> Procedure for Box<T> {
    /// Returns the inner procedure's type name.
    fn type_name(&self) -> &str {
        (**self).type_name()
    }

    /// Executes the inner procedure.
    async fn execute(&mut self, ctx: &Context) -> Result<Status> {
        (**self).execute(ctx).await
    }

    /// Rolls back the inner procedure.
    async fn rollback(&mut self, ctx: &Context) -> Result<()> {
        (**self).rollback(ctx).await
    }

    /// Reports whether the inner procedure supports rollback.
    fn rollback_supported(&self) -> bool {
        (**self).rollback_supported()
    }

    /// Dumps the inner procedure.
    fn dump(&self) -> Result<String> {
        (**self).dump()
    }

    /// Runs the inner procedure's recover hook rather than the no-op trait default.
    fn recover(&mut self) -> Result<()> {
        (**self).recover()
    }

    /// Returns the inner procedure's lock key.
    fn lock_key(&self) -> LockKey {
        (**self).lock_key()
    }
}
/// A string key tagged as either shared or exclusive.
///
/// The derived ordering compares the variant first (`Share` sorts before `Exclusive`)
/// and then the string; [`LockKey::new`] sorts keys with it.
///
/// NOTE(review): the actual locking behaviour of each mode is implemented outside this file.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StringKey {
/// Key tagged as shared.
Share(String),
/// Key tagged as exclusive.
Exclusive(String),
}
/// Collection of [`StringKey`]s returned by [`Procedure::lock_key`].
///
/// Backed by a `SmallVec` with an inline capacity of two keys.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct LockKey(SmallVec<[StringKey; 2]>);
impl StringKey {
    /// Consumes the key and returns the wrapped string, whatever its variant.
    pub fn into_string(self) -> String {
        match self {
            StringKey::Share(key) | StringKey::Exclusive(key) => key,
        }
    }

    /// Borrows the wrapped string, whatever the variant.
    pub fn as_string(&self) -> &String {
        match self {
            StringKey::Share(key) | StringKey::Exclusive(key) => key,
        }
    }
}
impl LockKey {
    /// Builds a lock key holding exactly one key.
    pub fn single(key: impl Into<StringKey>) -> LockKey {
        let key: StringKey = key.into();
        LockKey(smallvec![key])
    }

    /// Builds a lock key holding exactly one [`StringKey::Exclusive`] key.
    pub fn single_exclusive(key: impl Into<String>) -> LockKey {
        let key = StringKey::Exclusive(key.into());
        LockKey(smallvec![key])
    }

    /// Builds a lock key from `iter`, sorting the keys and dropping duplicates.
    pub fn new(iter: impl IntoIterator<Item = StringKey>) -> LockKey {
        let mut keys = iter.into_iter().collect::<SmallVec<[StringKey; 2]>>();
        // Sorting first makes equal keys adjacent, which is what `dedup` requires.
        keys.sort();
        keys.dedup();
        LockKey(keys)
    }

    /// Same as [`LockKey::new`], with every string wrapped as [`StringKey::Exclusive`].
    pub fn new_exclusive(iter: impl IntoIterator<Item = String>) -> LockKey {
        let exclusive_keys = iter.into_iter().map(StringKey::Exclusive);
        Self::new(exclusive_keys)
    }

    /// Iterates over the stored keys in their stored order.
    pub fn keys_to_lock(&self) -> impl Iterator<Item = &StringKey> {
        self.0.as_slice().iter()
    }

    /// Returns the `Debug` rendering of every key, e.g. `Exclusive("a")`.
    pub fn get_keys(&self) -> Vec<String> {
        let mut rendered = Vec::with_capacity(self.0.len());
        for key in &self.0 {
            rendered.push(format!("{key:?}"));
        }
        rendered
    }
}
/// Boxed [`Procedure`] trait object.
pub type BoxedProcedure = Box<dyn Procedure>;
/// A boxed procedure paired with its [`ProcedureId`].
///
/// This is the type accepted by [`ProcedureManager::submit`] and carried by
/// [`Status::Suspended`].
pub struct ProcedureWithId {
/// Id of the procedure.
pub id: ProcedureId,
/// The procedure itself.
pub procedure: BoxedProcedure,
}
impl ProcedureWithId {
    /// Pairs `procedure` with a freshly generated random [`ProcedureId`].
    pub fn with_random_id(procedure: BoxedProcedure) -> ProcedureWithId {
        let id = ProcedureId::random();
        Self { id, procedure }
    }
}
impl fmt::Debug for ProcedureWithId {
    /// Formats as `<type name>-<id>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { id, procedure } = self;
        write!(f, "{}-{}", procedure.type_name(), id)
    }
}
/// Error returned when a string cannot be parsed into a [`ProcedureId`].
#[derive(Debug, Snafu)]
pub struct ParseIdError {
/// The underlying UUID parse error.
source: uuid::Error,
}
/// Identifier of a procedure: a newtype over a [`Uuid`].
///
/// Displays, parses and (via the serde derives) serializes as the wrapped UUID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ProcedureId(Uuid);
impl ProcedureId {
    /// Generates a new id backed by a random (version 4) UUID.
    pub fn random() -> ProcedureId {
        let uuid = Uuid::new_v4();
        ProcedureId(uuid)
    }

    /// Parses an id from its UUID string form.
    ///
    /// # Errors
    ///
    /// Returns [`ParseIdError`] when `input` is not a valid UUID.
    pub fn parse_str(input: &str) -> std::result::Result<ProcedureId, ParseIdError> {
        let uuid = Uuid::parse_str(input).context(ParseIdSnafu)?;
        Ok(ProcedureId(uuid))
    }
}
impl fmt::Display for ProcedureId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl FromStr for ProcedureId {
type Err = ParseIdError;
fn from_str(s: &str) -> std::result::Result<ProcedureId, ParseIdError> {
ProcedureId::parse_str(s)
}
}
/// Boxed, sendable closure that builds a [`BoxedProcedure`] from a string.
///
/// NOTE(review): the `&str` argument is presumably the text produced by
/// [`Procedure::dump`] — confirm against the code that registers loaders.
pub type BoxedProcedureLoader = Box<dyn Fn(&str) -> Result<BoxedProcedure> + Send>;
/// State of a procedure as reported by [`ProcedureManager::procedure_state`].
///
/// Defaults to [`ProcedureState::Running`].
#[derive(Debug, Default, Clone)]
pub enum ProcedureState {
/// The procedure is running.
#[default]
Running,
/// The procedure finished, optionally with an output.
Done { output: Option<Output> },
/// The procedure is retrying; `error` is the associated error.
Retrying { error: Arc<Error> },
/// The procedure is preparing to roll back; `error` is the associated error.
PrepareRollback { error: Arc<Error> },
/// The procedure is rolling back; `error` is the associated error.
RollingBack { error: Arc<Error> },
/// The procedure failed; `error` is the associated error.
Failed { error: Arc<Error> },
}
impl ProcedureState {
pub fn failed(error: Arc<Error>) -> ProcedureState {
ProcedureState::Failed { error }
}
pub fn prepare_rollback(error: Arc<Error>) -> ProcedureState {
ProcedureState::PrepareRollback { error }
}
pub fn rolling_back(error: Arc<Error>) -> ProcedureState {
ProcedureState::RollingBack { error }
}
pub fn retrying(error: Arc<Error>) -> ProcedureState {
ProcedureState::Retrying { error }
}
pub fn is_running(&self) -> bool {
matches!(self, ProcedureState::Running)
}
pub fn is_done(&self) -> bool {
matches!(self, ProcedureState::Done { .. })
}
pub fn is_failed(&self) -> bool {
matches!(self, ProcedureState::Failed { .. })
}
pub fn is_retrying(&self) -> bool {
matches!(self, ProcedureState::Retrying { .. })
}
pub fn is_rolling_back(&self) -> bool {
matches!(self, ProcedureState::RollingBack { .. })
}
pub fn is_prepare_rollback(&self) -> bool {
matches!(self, ProcedureState::PrepareRollback { .. })
}
pub fn error(&self) -> Option<&Arc<Error>> {
match self {
ProcedureState::Failed { error } => Some(error),
ProcedureState::Retrying { error } => Some(error),
ProcedureState::RollingBack { error } => Some(error),
_ => None,
}
}
pub fn as_str_name(&self) -> &str {
match self {
ProcedureState::Running => "Running",
ProcedureState::Done { .. } => "Done",
ProcedureState::Retrying { .. } => "Retrying",
ProcedureState::Failed { .. } => "Failed",
ProcedureState::PrepareRollback { .. } => "PrepareRollback",
ProcedureState::RollingBack { .. } => "RollingBack",
}
}
}
/// Initial state of a procedure.
///
/// NOTE(review): presumably the state a procedure starts in when it is loaded or
/// submitted — the code that consumes this enum is not part of this file; confirm.
#[derive(Debug, Clone)]
pub enum InitProcedureState {
/// Start in the running state.
Running,
/// Start in the rolling-back state.
RollingBack,
}
/// Interface of a procedure manager: registers loaders, accepts procedures and
/// reports their state.
#[async_trait]
pub trait ProcedureManager: Send + Sync + 'static {
/// Registers `loader` under `name`.
///
/// NOTE(review): `name` presumably matches [`Procedure::type_name`] of the procedures
/// the loader can build — confirm with the implementation.
fn register_loader(&self, name: &str, loader: BoxedProcedureLoader) -> Result<()>;
/// Starts the manager.
async fn start(&self) -> Result<()>;
/// Stops the manager.
async fn stop(&self) -> Result<()>;
/// Submits a procedure and returns a [`Watcher`] for it.
async fn submit(&self, procedure: ProcedureWithId) -> Result<Watcher>;
/// Queries the [`ProcedureState`] of the procedure with the given id.
///
/// NOTE(review): `Ok(None)` presumably means the procedure is unknown — confirm.
async fn procedure_state(&self, procedure_id: ProcedureId) -> Result<Option<ProcedureState>>;
/// Returns a [`Watcher`] for the procedure with the given id, if one is available.
fn procedure_watcher(&self, procedure_id: ProcedureId) -> Option<Watcher>;
/// Lists information about procedures as [`ProcedureInfo`] entries.
async fn list_procedures(&self) -> Result<Vec<ProcedureInfo>>;
}
/// Shared, reference-counted [`ProcedureManager`] trait object.
pub type ProcedureManagerRef = Arc<dyn ProcedureManager>;
/// Summary of a procedure, as returned by [`ProcedureManager::list_procedures`].
#[derive(Debug, Clone)]
pub struct ProcedureInfo {
/// Id of the procedure.
pub id: ProcedureId,
/// Type name of the procedure.
pub type_name: String,
/// Start time in milliseconds, per the field name.
/// NOTE(review): the time reference (e.g. Unix epoch) is not visible here — confirm.
pub start_time_ms: i64,
/// End time in milliseconds, per the field name.
/// NOTE(review): the value used while the procedure has not ended is not visible here — confirm.
pub end_time_ms: i64,
/// State of the procedure.
pub state: ProcedureState,
/// Lock keys of the procedure as strings.
/// NOTE(review): presumably produced by [`LockKey::get_keys`] — confirm.
pub lock_keys: Vec<String>,
}
#[cfg(test)]
mod tests {
    use common_error::mock::MockError;
    use common_error::status_code::StatusCode;

    use super::*;

    /// `need_persist` mirrors the `persist` flag and is always `false` once done.
    #[test]
    fn test_status() {
        let cases = [
            (Status::Executing { persist: false }, false),
            (Status::Executing { persist: true }, true),
            (
                Status::Suspended {
                    subprocedures: Vec::new(),
                    persist: false,
                },
                false,
            ),
            (
                Status::Suspended {
                    subprocedures: Vec::new(),
                    persist: true,
                },
                true,
            ),
            (Status::done(), false),
        ];

        for (status, expected) in cases.iter() {
            assert_eq!(
                *expected,
                status.need_persist(),
                "unexpected need_persist() for {status:?}"
            );
        }
    }

    /// A single key is kept as is; multiple keys come out sorted and deduplicated.
    #[test]
    fn test_lock_key() {
        let entity = "catalog.schema.my_table";
        let single = LockKey::single_exclusive(entity);
        let locked: Vec<&StringKey> = single.keys_to_lock().collect();
        assert_eq!(vec![&StringKey::Exclusive(entity.to_string())], locked);

        let multi = LockKey::new_exclusive(["b", "c", "a", "c"].iter().map(|s| s.to_string()));
        let expected: Vec<StringKey> = ["a", "b", "c"]
            .iter()
            .map(|s| StringKey::Exclusive(s.to_string()))
            .collect();
        let locked: Vec<StringKey> = multi.keys_to_lock().cloned().collect();
        assert_eq!(expected, locked);
    }

    /// An id round-trips through its string form via both parsing entry points.
    #[test]
    fn test_procedure_id() {
        let id = ProcedureId::random();
        let text = id.to_string();
        assert_eq!(id.0.to_string(), text);

        let via_parse_str = ProcedureId::parse_str(&text).unwrap();
        assert_eq!(id, via_parse_str);

        let via_from_str: ProcedureId = text.parse().unwrap();
        assert_eq!(id, via_from_str);
    }

    /// An id serializes to a JSON string and deserializes back to the same id.
    #[test]
    fn test_procedure_id_serialization() {
        let id = ProcedureId::random();

        let encoded = serde_json::to_string(&id).unwrap();
        assert_eq!(format!("\"{id}\""), encoded);

        let decoded: ProcedureId = serde_json::from_str(&encoded).unwrap();
        assert_eq!(id, decoded);
    }

    /// Basic predicates and the error accessor of `ProcedureState`.
    #[test]
    fn test_procedure_state() {
        let running = ProcedureState::Running;
        assert!(running.is_running());
        assert!(running.error().is_none());

        assert!(ProcedureState::Done { output: None }.is_done());

        let source = Error::external(MockError::new(StatusCode::Unexpected));
        let failed = ProcedureState::failed(Arc::new(source));
        assert!(failed.is_failed());
        assert!(failed.error().is_some());
    }
}