common_procedure_test/
lib.rs1use std::collections::HashMap;
18use std::sync::Arc;
19
20use async_trait::async_trait;
21use common_procedure::local::{DynamicKeyLockGuard, acquire_dynamic_key_lock};
22use common_procedure::rwlock::KeyRwLock;
23use common_procedure::store::poison_store::PoisonStore;
24use common_procedure::test_util::InMemoryPoisonStore;
25use common_procedure::{
26 Context, ContextProvider, Output, PoisonKey, Procedure, ProcedureId, ProcedureState,
27 ProcedureWithId, Result, Status, StringKey,
28};
29use tokio::sync::watch::Receiver;
30
31#[derive(Default)]
33pub struct MockContextProvider {
34 states: HashMap<ProcedureId, ProcedureState>,
35 poison_manager: InMemoryPoisonStore,
36 dynamic_key_lock: Arc<KeyRwLock<String>>,
37}
38
39impl MockContextProvider {
40 pub fn new(states: HashMap<ProcedureId, ProcedureState>) -> MockContextProvider {
42 MockContextProvider {
43 states,
44 poison_manager: InMemoryPoisonStore::default(),
45 dynamic_key_lock: Arc::new(KeyRwLock::new()),
46 }
47 }
48
49 pub fn poison_manager(&self) -> &InMemoryPoisonStore {
51 &self.poison_manager
52 }
53}
54
55#[async_trait]
56impl ContextProvider for MockContextProvider {
57 async fn procedure_state(&self, procedure_id: ProcedureId) -> Result<Option<ProcedureState>> {
58 Ok(self.states.get(&procedure_id).cloned())
59 }
60
61 async fn procedure_state_receiver(
62 &self,
63 _procedure_id: ProcedureId,
64 ) -> Result<Option<Receiver<ProcedureState>>> {
65 Ok(None)
66 }
67
68 async fn try_put_poison(&self, key: &PoisonKey, procedure_id: ProcedureId) -> Result<()> {
69 self.poison_manager
70 .try_put_poison(key.to_string(), procedure_id.to_string())
71 .await
72 }
73
74 async fn acquire_lock(&self, key: &StringKey) -> DynamicKeyLockGuard {
75 acquire_dynamic_key_lock(&self.dynamic_key_lock, key).await
76 }
77}
78
79pub async fn execute_procedure_until_done(procedure: &mut dyn Procedure) -> Option<Output> {
84 let ctx = Context {
85 procedure_id: ProcedureId::random(),
86 provider: Arc::new(MockContextProvider::default()),
87 };
88
89 loop {
90 match procedure.execute(&ctx).await.unwrap() {
91 Status::Executing { .. } => (),
92 Status::Suspended { subprocedures, .. } => assert!(
93 subprocedures.is_empty(),
94 "Executing subprocedure is unsupported"
95 ),
96 Status::Done { output } => return output,
97 Status::Poisoned { .. } => return None,
98 }
99 }
100}
101
102pub async fn execute_procedure_once(
106 procedure_id: ProcedureId,
107 provider: MockContextProvider,
108 procedure: &mut dyn Procedure,
109) -> bool {
110 let ctx = Context {
111 procedure_id,
112 provider: Arc::new(provider),
113 };
114
115 match procedure.execute(&ctx).await.unwrap() {
116 Status::Executing { .. } => false,
117 Status::Suspended { subprocedures, .. } => {
118 assert!(
119 subprocedures.is_empty(),
120 "Executing subprocedure is unsupported"
121 );
122 false
123 }
124 Status::Done { .. } => true,
125 Status::Poisoned { .. } => false,
126 }
127}
128
129pub async fn execute_until_suspended_or_done(
133 procedure_id: ProcedureId,
134 provider: MockContextProvider,
135 procedure: &mut dyn Procedure,
136) -> Option<Vec<ProcedureWithId>> {
137 let ctx = Context {
138 procedure_id,
139 provider: Arc::new(provider),
140 };
141
142 loop {
143 match procedure.execute(&ctx).await.unwrap() {
144 Status::Executing { .. } => (),
145 Status::Suspended { subprocedures, .. } => return Some(subprocedures),
146 Status::Done { .. } => break,
147 Status::Poisoned { .. } => unreachable!(),
148 }
149 }
150
151 None
152}
153
154pub fn new_test_procedure_context() -> Context {
155 Context {
156 procedure_id: ProcedureId::random(),
157 provider: Arc::new(MockContextProvider::default()),
158 }
159}
160
161pub async fn execute_procedure_until<P: Procedure>(procedure: &mut P, until: impl Fn(&P) -> bool) {
162 let mut reached = false;
163 let context = new_test_procedure_context();
164 while !matches!(
165 procedure.execute(&context).await.unwrap(),
166 Status::Done { .. }
167 ) {
168 if until(procedure) {
169 reached = true;
170 break;
171 }
172 }
173 assert!(
174 reached,
175 "procedure '{}' did not reach the expected state",
176 procedure.type_name()
177 );
178}