// common_procedure_test/lib.rs
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use common_procedure::{
Context, ContextProvider, Output, Procedure, ProcedureId, ProcedureState, ProcedureWithId,
Result, Status,
};
/// A [`ContextProvider`] for tests that answers state queries from a fixed map.
///
/// The default value holds an empty map, so every lookup yields `None`.
#[derive(Default)]
pub struct MockContextProvider {
    /// Procedure states keyed by procedure id; consulted by `procedure_state`.
    states: HashMap<ProcedureId, ProcedureState>,
}

impl MockContextProvider {
    /// Creates a provider backed by the given `states` map.
    pub fn new(states: HashMap<ProcedureId, ProcedureState>) -> Self {
        Self { states }
    }
}
#[async_trait]
impl ContextProvider for MockContextProvider {
    /// Returns a clone of the state stored for `procedure_id`, or `None` when the
    /// id is not present in the map. This implementation never returns an error.
    async fn procedure_state(&self, procedure_id: ProcedureId) -> Result<Option<ProcedureState>> {
        let found = self.states.get(&procedure_id);
        Ok(found.cloned())
    }
}
/// Repeatedly executes `procedure` until it reports [`Status::Done`] and returns
/// the output carried by that status.
///
/// The procedure runs under a fresh context with a random id and an empty
/// [`MockContextProvider`].
///
/// # Panics
///
/// Panics if an execution step returns an error, or if the procedure suspends
/// with a non-empty list of subprocedures (running subprocedures is unsupported).
pub async fn execute_procedure_until_done(procedure: &mut dyn Procedure) -> Option<Output> {
    let ctx = Context {
        procedure_id: ProcedureId::random(),
        provider: Arc::new(MockContextProvider::default()),
    };

    loop {
        let status = procedure.execute(&ctx).await.unwrap();
        match status {
            Status::Done { output } => break output,
            Status::Suspended { subprocedures, .. } => {
                assert!(
                    subprocedures.is_empty(),
                    "Executing subprocedure is unsupported"
                );
            }
            // Still in progress: run another step.
            Status::Executing { .. } => {}
        }
    }
}
/// Executes a single step of `procedure` under a context built from
/// `procedure_id` and `provider`.
///
/// Returns `true` when the step finished the procedure ([`Status::Done`]) and
/// `false` when it is still executing or suspended.
///
/// # Panics
///
/// Panics if the step returns an error, or if the procedure suspends with a
/// non-empty list of subprocedures (running subprocedures is unsupported).
pub async fn execute_procedure_once(
    procedure_id: ProcedureId,
    provider: MockContextProvider,
    procedure: &mut dyn Procedure,
) -> bool {
    let ctx = Context {
        procedure_id,
        provider: Arc::new(provider),
    };

    let status = procedure.execute(&ctx).await.unwrap();

    // A suspension is only tolerated when it does not ask for subprocedures.
    if let Status::Suspended { subprocedures, .. } = &status {
        assert!(
            subprocedures.is_empty(),
            "Executing subprocedure is unsupported"
        );
    }

    matches!(status, Status::Done { .. })
}
/// Keeps executing `procedure` until it either suspends or finishes.
///
/// Returns `Some(subprocedures)` with the subprocedures reported by the first
/// [`Status::Suspended`], or `None` if the procedure reaches [`Status::Done`]
/// without suspending.
///
/// # Panics
///
/// Panics if an execution step returns an error.
pub async fn execute_until_suspended_or_done(
    procedure_id: ProcedureId,
    provider: MockContextProvider,
    procedure: &mut dyn Procedure,
) -> Option<Vec<ProcedureWithId>> {
    let ctx = Context {
        procedure_id,
        provider: Arc::new(provider),
    };

    loop {
        match procedure.execute(&ctx).await.unwrap() {
            Status::Suspended { subprocedures, .. } => break Some(subprocedures),
            Status::Done { .. } => break None,
            Status::Executing { .. } => continue,
        }
    }
}
/// Builds a [`Context`] for tests: a random procedure id paired with an empty
/// [`MockContextProvider`].
pub fn new_test_procedure_context() -> Context {
    let procedure_id = ProcedureId::random();
    Context {
        procedure_id,
        provider: Arc::new(MockContextProvider::default()),
    }
}
/// Executes `procedure` step by step until the `until` predicate holds.
///
/// After every step that does not finish the procedure, `until` is evaluated
/// against the procedure; the function returns as soon as it yields `true`.
/// The predicate is not consulted for the step that returns [`Status::Done`].
///
/// # Panics
///
/// Panics if an execution step returns an error, or if the procedure reaches
/// [`Status::Done`] before `until` has returned `true`.
pub async fn execute_procedure_until<P: Procedure>(procedure: &mut P, until: impl Fn(&P) -> bool) {
    let context = new_test_procedure_context();

    let reached = loop {
        let finished = matches!(
            procedure.execute(&context).await.unwrap(),
            Status::Done { .. }
        );
        if finished {
            // The procedure completed without the predicate ever holding.
            break false;
        }
        if until(procedure) {
            break true;
        }
    };

    assert!(
        reached,
        "procedure '{}' did not reach the expected state",
        procedure.type_name()
    );
}