
meta_srv/procedure/repartition/group.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub(crate) mod apply_staging_manifest;
pub(crate) mod enter_staging_region;
pub(crate) mod remap_manifest;
pub(crate) mod repartition_end;
pub(crate) mod repartition_start;
pub(crate) mod sync_region;
pub(crate) mod update_metadata;
pub(crate) mod utils;

use std::any::Any;
use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::time::{Duration, Instant};

use common_error::ext::BoxedError;
use common_meta::cache_invalidator::CacheInvalidatorRef;
use common_meta::ddl::DdlContext;
use common_meta::instruction::CacheIdent;
use common_meta::key::datanode_table::{DatanodeTableValue, RegionInfo};
use common_meta::key::table_route::TableRouteValue;
use common_meta::key::{DeserializedValueWithBytes, TableMetadataManagerRef};
use common_meta::lock_key::{CatalogLock, RegionLock, SchemaLock};
use common_meta::peer::Peer;
use common_meta::rpc::router::RegionRoute;
use common_procedure::error::{FromJsonSnafu, ToJsonSnafu};
use common_procedure::{
    Context as ProcedureContext, Error as ProcedureError, LockKey, Procedure,
    Result as ProcedureResult, Status, StringKey, UserMetadata,
};
use common_telemetry::{error, info, warn};
use serde::{Deserialize, Serialize};
use snafu::{OptionExt, ResultExt};
use store_api::storage::{RegionId, TableId};
use uuid::Uuid;

use crate::error::{self, Result};
use crate::procedure::repartition::group::apply_staging_manifest::ApplyStagingManifest;
use crate::procedure::repartition::group::enter_staging_region::EnterStagingRegion;
use crate::procedure::repartition::group::remap_manifest::RemapManifest;
use crate::procedure::repartition::group::repartition_start::RepartitionStart;
use crate::procedure::repartition::group::update_metadata::UpdateMetadata;
use crate::procedure::repartition::plan::RegionDescriptor;
use crate::procedure::repartition::utils::get_datanode_table_value;
use crate::procedure::repartition::{self};
use crate::service::mailbox::MailboxRef;
#[derive(Debug, Clone, Default)]
pub struct Metrics {
    /// Elapsed time of flushing pending deallocate regions.
    flush_pending_deallocate_regions_elapsed: Duration,
    /// Elapsed time of entering staging region.
    enter_staging_region_elapsed: Duration,
    /// Elapsed time of applying staging manifest.
    apply_staging_manifest_elapsed: Duration,
    /// Elapsed time of remapping manifest.
    remap_manifest_elapsed: Duration,
    /// Elapsed time of updating metadata.
    update_metadata_elapsed: Duration,
}

impl Display for Metrics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total = self.flush_pending_deallocate_regions_elapsed
            + self.enter_staging_region_elapsed
            + self.apply_staging_manifest_elapsed
            + self.remap_manifest_elapsed
            + self.update_metadata_elapsed;
        write!(f, "total: {:?}", total)?;
        let mut parts = Vec::with_capacity(5);
        if self.flush_pending_deallocate_regions_elapsed > Duration::ZERO {
            parts.push(format!(
                "flush_pending_deallocate_regions_elapsed: {:?}",
                self.flush_pending_deallocate_regions_elapsed
            ));
        }
        if self.enter_staging_region_elapsed > Duration::ZERO {
            parts.push(format!(
                "enter_staging_region_elapsed: {:?}",
                self.enter_staging_region_elapsed
            ));
        }
        if self.apply_staging_manifest_elapsed > Duration::ZERO {
            parts.push(format!(
                "apply_staging_manifest_elapsed: {:?}",
                self.apply_staging_manifest_elapsed
            ));
        }
        if self.remap_manifest_elapsed > Duration::ZERO {
            parts.push(format!(
                "remap_manifest_elapsed: {:?}",
                self.remap_manifest_elapsed
            ));
        }
        if self.update_metadata_elapsed > Duration::ZERO {
            parts.push(format!(
                "update_metadata_elapsed: {:?}",
                self.update_metadata_elapsed
            ));
        }

        if !parts.is_empty() {
            write!(f, ", {}", parts.join(", "))?;
        }
        Ok(())
    }
}
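
// Example rendering (hypothetical durations): a `Metrics` with only
// `enter_staging_region_elapsed = 1s` and `remap_manifest_elapsed = 2s` set
// displays as:
//
//   total: 3s, enter_staging_region_elapsed: 1s, remap_manifest_elapsed: 2s
//
// Phases with zero elapsed time are omitted from the comma-separated tail.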

impl Metrics {
    /// Updates the elapsed time of entering staging region.
    pub fn update_enter_staging_region_elapsed(&mut self, elapsed: Duration) {
        self.enter_staging_region_elapsed += elapsed;
    }

    /// Updates the elapsed time of flushing pending deallocate regions.
    pub fn update_flush_pending_deallocate_regions_elapsed(&mut self, elapsed: Duration) {
        self.flush_pending_deallocate_regions_elapsed += elapsed;
    }

    /// Updates the elapsed time of applying staging manifest.
    pub fn update_apply_staging_manifest_elapsed(&mut self, elapsed: Duration) {
        self.apply_staging_manifest_elapsed += elapsed;
    }

    /// Updates the elapsed time of remapping manifest.
    pub fn update_remap_manifest_elapsed(&mut self, elapsed: Duration) {
        self.remap_manifest_elapsed += elapsed;
    }

    /// Updates the elapsed time of updating metadata.
    pub fn update_update_metadata_elapsed(&mut self, elapsed: Duration) {
        self.update_metadata_elapsed += elapsed;
    }
}

pub type GroupId = Uuid;

pub struct RepartitionGroupProcedure {
    state: Box<dyn State>,
    context: Context,
}

#[derive(Debug, Serialize)]
struct RepartitionGroupData<'a> {
    persistent_ctx: &'a PersistentContext,
    state: &'a dyn State,
}

#[derive(Debug, Deserialize)]
struct RepartitionGroupDataOwned {
    persistent_ctx: PersistentContext,
    state: Box<dyn State>,
}

impl RepartitionGroupProcedure {
    pub(crate) const TYPE_NAME: &'static str = "metasrv-procedure::RepartitionGroup";

    pub fn new(persistent_context: PersistentContext, context: &repartition::Context) -> Self {
        let state = Box::new(RepartitionStart);

        Self {
            state,
            context: Context {
                persistent_ctx: persistent_context,
                cache_invalidator: context.cache_invalidator.clone(),
                table_metadata_manager: context.table_metadata_manager.clone(),
                mailbox: context.mailbox.clone(),
                server_addr: context.server_addr.clone(),
                start_time: Instant::now(),
                volatile_ctx: VolatileContext::default(),
            },
        }
    }

    pub fn from_json<F>(json: &str, ctx_factory: F) -> ProcedureResult<Self>
    where
        F: FnOnce(PersistentContext) -> Context,
    {
        let RepartitionGroupDataOwned {
            state,
            persistent_ctx,
        } = serde_json::from_str(json).context(FromJsonSnafu)?;
        let context = ctx_factory(persistent_ctx);

        Ok(Self { state, context })
    }
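
    // A minimal sketch of restoring this procedure from its persisted JSON,
    // assuming a `DdlContext`, a `MailboxRef`, and the server address are in
    // scope (the bindings below are illustrative):
    //
    //   let procedure = RepartitionGroupProcedure::from_json(&json, |persistent_ctx| {
    //       Context::new(&ddl_ctx, mailbox.clone(), server_addr.clone(), persistent_ctx)
    //   })?;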

    async fn rollback_inner(&mut self, procedure_ctx: &ProcedureContext) -> Result<()> {
        if !self.should_rollback_metadata() {
            return Ok(());
        }

        let table_lock =
            common_meta::lock_key::TableLock::Write(self.context.persistent_ctx.table_id).into();
        let _guard = procedure_ctx.provider.acquire_lock(&table_lock).await;
        UpdateMetadata::RollbackStaging
            .rollback_staging_regions(&mut self.context)
            .await?;

        if let Err(err) = self.context.invalidate_table_cache().await {
            warn!(
                err;
                "Failed to broadcast the invalidate table cache message during repartition group rollback"
            );
        }

        Ok(())
    }

    /// Returns whether group rollback should revert staging metadata.
    ///
    /// This uses an "after metadata apply, before exit staging" semantic.
    /// Once execution reaches `UpdateMetadata::ApplyStaging` or any later staging state,
    /// rollback must restore table-route metadata back to the pre-apply view.
    ///
    /// State flow:
    /// `RepartitionStart -> SyncRegion -> UpdateMetadata::ApplyStaging -> EnterStagingRegion`
    /// `                 -> RemapManifest -> ApplyStagingManifest -> UpdateMetadata::ExitStaging -> RepartitionEnd`
    /// `                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^`
    /// `                               rollback staging metadata`
    ///
    /// Notes:
    /// - `RepartitionStart` / `SyncRegion`: no-op, metadata has not been staged yet.
    /// - `UpdateMetadata::ApplyStaging` / `EnterStagingRegion` / `RemapManifest` /
    ///   `ApplyStagingManifest` / `UpdateMetadata::RollbackStaging`: rollback-active.
    /// - `UpdateMetadata::ExitStaging` / `RepartitionEnd`: excluded, because metadata has
    ///   already moved into the post-commit exit path.
    fn should_rollback_metadata(&self) -> bool {
        self.state.as_any().is::<EnterStagingRegion>()
            || self.state.as_any().is::<RemapManifest>()
            || self.state.as_any().is::<ApplyStagingManifest>()
            || self
                .state
                .as_any()
                .downcast_ref::<UpdateMetadata>()
                .is_some_and(|state| {
                    matches!(
                        state,
                        UpdateMetadata::ApplyStaging | UpdateMetadata::RollbackStaging
                    )
                })
    }
}

#[async_trait::async_trait]
impl Procedure for RepartitionGroupProcedure {
    fn type_name(&self) -> &str {
        Self::TYPE_NAME
    }

    async fn rollback(&mut self, ctx: &ProcedureContext) -> ProcedureResult<()> {
        self.rollback_inner(ctx)
            .await
            .map_err(ProcedureError::external)
    }

    #[tracing::instrument(skip_all, fields(
        state = %self.state.name(),
        table_id = self.context.persistent_ctx.table_id,
        group_id = %self.context.persistent_ctx.group_id,
    ))]
    async fn execute(&mut self, ctx: &ProcedureContext) -> ProcedureResult<Status> {
        let state = &mut self.state;
        let state_name = state.name();
        // Log the state transition.
        info!(
            "Repartition group procedure executing state: {}, group id: {}, table id: {}",
            state_name,
            self.context.persistent_ctx.group_id,
            self.context.persistent_ctx.table_id
        );

        match state.next(&mut self.context, ctx).await {
            Ok((next, status)) => {
                *state = next;
                Ok(status)
            }
            Err(e) => {
                if e.is_retryable() {
                    Err(ProcedureError::retry_later(e))
                } else {
                    error!(
                        e;
                        "Repartition group procedure failed, group id: {}, table id: {}",
                        self.context.persistent_ctx.group_id,
                        self.context.persistent_ctx.table_id,
                    );
                    Err(ProcedureError::external(e))
                }
            }
        }
    }

    fn rollback_supported(&self) -> bool {
        true
    }

    fn dump(&self) -> ProcedureResult<String> {
        let data = RepartitionGroupData {
            persistent_ctx: &self.context.persistent_ctx,
            state: self.state.as_ref(),
        };
        serde_json::to_string(&data).context(ToJsonSnafu)
    }

    fn lock_key(&self) -> LockKey {
        LockKey::new(self.context.persistent_ctx.lock_key())
    }

    fn user_metadata(&self) -> Option<UserMetadata> {
        // TODO(weny): support user metadata.
        None
    }
}

pub struct Context {
    pub persistent_ctx: PersistentContext,

    pub cache_invalidator: CacheInvalidatorRef,

    pub table_metadata_manager: TableMetadataManagerRef,

    pub mailbox: MailboxRef,

    pub server_addr: String,

    pub start_time: Instant,

    pub volatile_ctx: VolatileContext,
}

#[derive(Debug, Clone, Default)]
pub struct VolatileContext {
    pub metrics: Metrics,
}

impl Context {
    pub fn new(
        ddl_ctx: &DdlContext,
        mailbox: MailboxRef,
        server_addr: String,
        persistent_ctx: PersistentContext,
    ) -> Self {
        Self {
            persistent_ctx,
            cache_invalidator: ddl_ctx.cache_invalidator.clone(),
            table_metadata_manager: ddl_ctx.table_metadata_manager.clone(),
            mailbox,
            server_addr,
            start_time: Instant::now(),
            volatile_ctx: VolatileContext::default(),
        }
    }
}

/// The result of the group preparation phase, containing validated region routes.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GroupPrepareResult {
    /// The validated source region routes.
    pub source_routes: Vec<RegionRoute>,
    /// Validated target region routes used for metadata rollback (logical rollback).
    pub target_routes: Vec<RegionRoute>,
    /// The primary source region id (first source region), used for retrieving region options.
    pub central_region: RegionId,
    /// The peer where the primary source region is located.
    pub central_region_datanode: Peer,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PersistentContext {
    /// The id of the repartition group.
    pub group_id: GroupId,
    /// The table id of the repartition group.
    pub table_id: TableId,
    /// The catalog name of the repartition group.
    pub catalog_name: String,
    /// The schema name of the repartition group.
    pub schema_name: String,
    /// The source regions of the repartition group.
    pub sources: Vec<RegionDescriptor>,
    /// The target regions of the repartition group.
    pub targets: Vec<RegionDescriptor>,
    /// For each `source region`, the corresponding
    /// `target regions` that overlap with it.
    pub region_mapping: HashMap<RegionId, Vec<RegionId>>,
    /// The result of group prepare.
    /// The value will be set in [RepartitionStart](crate::procedure::repartition::group::repartition_start::RepartitionStart) state.
    pub group_prepare_result: Option<GroupPrepareResult>,
    /// The staging manifest paths of the repartition group.
    /// The value will be set in [RemapManifest](crate::procedure::repartition::group::remap_manifest::RemapManifest) state.
    pub staging_manifest_paths: HashMap<RegionId, String>,
    /// Whether sync region is needed for this group.
    pub sync_region: bool,
    /// The region ids of the newly allocated regions.
    pub allocated_region_ids: Vec<RegionId>,
    /// The region ids of the regions that are pending deallocation.
    pub pending_deallocate_region_ids: Vec<RegionId>,
    /// The timeout for repartition operations.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,
}

impl PersistentContext {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        group_id: GroupId,
        table_id: TableId,
        catalog_name: String,
        schema_name: String,
        sources: Vec<RegionDescriptor>,
        targets: Vec<RegionDescriptor>,
        region_mapping: HashMap<RegionId, Vec<RegionId>>,
        sync_region: bool,
        allocated_region_ids: Vec<RegionId>,
        pending_deallocate_region_ids: Vec<RegionId>,
        timeout: Duration,
    ) -> Self {
        Self {
            group_id,
            table_id,
            catalog_name,
            schema_name,
            sources,
            targets,
            region_mapping,
            group_prepare_result: None,
            staging_manifest_paths: HashMap::new(),
            sync_region,
            allocated_region_ids,
            pending_deallocate_region_ids,
            timeout,
        }
    }

    pub fn lock_key(&self) -> Vec<StringKey> {
        let mut lock_keys = Vec::with_capacity(2 + self.sources.len());
        lock_keys.extend([
            CatalogLock::Read(&self.catalog_name).into(),
            SchemaLock::read(&self.catalog_name, &self.schema_name).into(),
        ]);
        for source in &self.sources {
            lock_keys.push(RegionLock::Write(source.region_id).into());
        }
        lock_keys
    }
}
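
// For illustration (hypothetical values): a group over table 1024 in
// `test_catalog`/`test_schema` with source regions 1 and 2 yields, in order:
//
//   CatalogLock::Read("test_catalog")
//   SchemaLock::read("test_catalog", "test_schema")
//   RegionLock::Write(RegionId::new(1024, 1))
//   RegionLock::Write(RegionId::new(1024, 2))
//
// i.e. shared catalog/schema locks plus an exclusive lock per source region.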

impl Context {
    /// Retrieves the table route value for the given table id.
    ///
    /// Retry:
    /// - Failed to retrieve the table metadata.
    ///
    /// Abort:
    /// - Table route not found.
    pub async fn get_table_route_value(
        &self,
    ) -> Result<DeserializedValueWithBytes<TableRouteValue>> {
        let table_id = self.persistent_ctx.table_id;
        let group_id = self.persistent_ctx.group_id;
        let table_route_value = self
            .table_metadata_manager
            .table_route_manager()
            .table_route_storage()
            .get_with_raw_bytes(table_id)
            .await
            .map_err(BoxedError::new)
            .with_context(|_| error::RetryLaterWithSourceSnafu {
                reason: format!(
                    "Failed to get table route for table: {}, repartition group: {}",
                    table_id, group_id
                ),
            })?
            .context(error::TableRouteNotFoundSnafu { table_id })?;

        Ok(table_route_value)
    }

    /// Returns the `datanode_table_value`.
    ///
    /// Retry:
    /// - Failed to retrieve the datanode table metadata.
    pub async fn get_datanode_table_value(
        &self,
        table_id: TableId,
        datanode_id: u64,
    ) -> Result<DatanodeTableValue> {
        get_datanode_table_value(&self.table_metadata_manager, table_id, datanode_id).await
    }

    /// Broadcasts the invalidate table cache message.
    pub async fn invalidate_table_cache(&self) -> Result<()> {
        let table_id = self.persistent_ctx.table_id;
        let group_id = self.persistent_ctx.group_id;
        let subject = format!(
            "Invalidate table cache for repartition table, group: {}, table: {}",
            group_id, table_id,
        );
        let ctx = common_meta::cache_invalidator::Context {
            subject: Some(subject),
        };
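        // Best-effort broadcast: errors from the invalidator are intentionally
        // discarded here.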
        let _ = self
            .cache_invalidator
            .invalidate(&ctx, &[CacheIdent::TableId(table_id)])
            .await;
        Ok(())
    }

    /// Updates the table route.
    ///
    /// Retry:
    /// - Failed to retrieve the datanode table metadata.
    ///
    /// Abort:
    /// - Table route not found.
    /// - Failed to update the table route.
    pub async fn update_table_route(
        &self,
        current_table_route_value: &DeserializedValueWithBytes<TableRouteValue>,
        new_region_routes: Vec<RegionRoute>,
    ) -> Result<()> {
        let table_id = self.persistent_ctx.table_id;
        let group_id = self.persistent_ctx.group_id;
        // Safety: the prepare result is set in the [RepartitionStart] state.
        let prepare_result = self.persistent_ctx.group_prepare_result.as_ref().unwrap();
        let central_region_datanode_table_value = self
            .get_datanode_table_value(table_id, prepare_result.central_region_datanode.id)
            .await?;
        let RegionInfo {
            region_options,
            region_wal_options,
            ..
        } = &central_region_datanode_table_value.region_info;

        info!(
            "Updating table route for table: {}, group_id: {}, new region routes: {:?}",
            table_id, group_id, new_region_routes
        );
        self.table_metadata_manager
            .update_table_route(
                table_id,
                central_region_datanode_table_value.region_info.clone(),
                current_table_route_value,
                new_region_routes,
                region_options,
                region_wal_options,
            )
            .await
            .context(error::TableMetadataManagerSnafu)
    }

    /// Updates the table repart mapping.
    pub async fn update_table_repart_mapping(&self) -> Result<()> {
        info!(
            "Updating table repart mapping for table: {}, group_id: {}, region mapping: {:?}",
            self.persistent_ctx.table_id,
            self.persistent_ctx.group_id,
            self.persistent_ctx.region_mapping
        );

        self.table_metadata_manager
            .table_repart_manager()
            .update_mappings(
                self.persistent_ctx.table_id,
                &self.persistent_ctx.region_mapping,
            )
            .await
            .context(error::TableMetadataManagerSnafu)
    }

    /// Returns the remaining timeout budget for the next operation.
    ///
    /// Returns `None` if the overall repartition timeout has already elapsed.
    pub fn next_operation_timeout(&self) -> Option<Duration> {
        self.persistent_ctx
            .timeout
            .checked_sub(self.start_time.elapsed())
    }
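
    // For example (hypothetical numbers): with `timeout = 120s` and 30s elapsed
    // since `start_time`, this returns `Some(90s)`; once more than 120s have
    // elapsed, it returns `None`.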

    /// Updates the elapsed time of entering staging region.
    pub fn update_enter_staging_region_elapsed(&mut self, elapsed: Duration) {
        self.volatile_ctx
            .metrics
            .update_enter_staging_region_elapsed(elapsed);
    }

    /// Updates the elapsed time of flushing pending deallocate regions.
    pub fn update_flush_pending_deallocate_regions_elapsed(&mut self, elapsed: Duration) {
        self.volatile_ctx
            .metrics
            .update_flush_pending_deallocate_regions_elapsed(elapsed);
    }

    /// Updates the elapsed time of applying staging manifest.
    pub fn update_apply_staging_manifest_elapsed(&mut self, elapsed: Duration) {
        self.volatile_ctx
            .metrics
            .update_apply_staging_manifest_elapsed(elapsed);
    }

    /// Updates the elapsed time of remapping manifest.
    pub fn update_remap_manifest_elapsed(&mut self, elapsed: Duration) {
        self.volatile_ctx
            .metrics
            .update_remap_manifest_elapsed(elapsed);
    }

    /// Updates the elapsed time of updating metadata.
    pub fn update_update_metadata_elapsed(&mut self, elapsed: Duration) {
        self.volatile_ctx
            .metrics
            .update_update_metadata_elapsed(elapsed);
    }
}

/// Returns the region routes of the given table route value.
///
/// Abort:
/// - Table route value is not physical.
pub fn region_routes(
    table_id: TableId,
    table_route_value: &TableRouteValue,
) -> Result<&Vec<RegionRoute>> {
    table_route_value
        .region_routes()
        .with_context(|_| error::UnexpectedLogicalRouteTableSnafu {
            err_msg: format!(
                "TableRoute({:?}) is a non-physical TableRouteValue.",
                table_id
            ),
        })
}

#[async_trait::async_trait]
#[typetag::serde(tag = "repartition_group_state")]
pub(crate) trait State: Sync + Send + Debug {
    fn name(&self) -> &'static str {
        let type_name = std::any::type_name::<Self>();
        // Keep only the short type name (the last path segment).
        type_name.split("::").last().unwrap_or(type_name)
    }
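
    // E.g. `std::any::type_name::<RepartitionStart>()` yields the full module
    // path ending in `repartition_start::RepartitionStart`, so the default
    // `name()` returns just "RepartitionStart".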

    /// Yields the next [State] and [Status].
    async fn next(
        &mut self,
        ctx: &mut Context,
        procedure_ctx: &ProcedureContext,
    ) -> Result<(Box<dyn State>, Status)>;

    fn as_any(&self) -> &dyn Any;
}
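
// A minimal sketch of a `State` implementation (illustrative only; the real
// states live in the submodules declared at the top of this file):
//
//   #[derive(Debug, serde::Serialize, serde::Deserialize)]
//   struct NoopState;
//
//   #[async_trait::async_trait]
//   #[typetag::serde]
//   impl State for NoopState {
//       async fn next(
//           &mut self,
//           _ctx: &mut Context,
//           _procedure_ctx: &ProcedureContext,
//       ) -> Result<(Box<dyn State>, Status)> {
//           // Terminal transition: report the procedure as done.
//           Ok((Box::new(NoopState), Status::done()))
//       }
//
//       fn as_any(&self) -> &dyn Any {
//           self
//       }
//   }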

#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::sync::Arc;
    use std::time::Duration;

    use common_meta::key::TableMetadataManager;
    use common_meta::kv_backend::test_util::MockKvBackendBuilder;
    use common_meta::peer::Peer;
    use common_meta::rpc::router::{Region, RegionRoute};
    use common_procedure::{Context as ProcedureContext, Procedure, ProcedureId};
    use common_procedure_test::MockContextProvider;
    use partition::expr::PartitionExpr;
    use store_api::storage::RegionId;

    use super::{
        Context, PersistentContext, RepartitionGroupProcedure, RepartitionStart, State,
        region_routes,
    };
    use crate::error::Error;
    use crate::procedure::repartition::dispatch::build_region_mapping;
    use crate::procedure::repartition::group::apply_staging_manifest::ApplyStagingManifest;
    use crate::procedure::repartition::group::enter_staging_region::EnterStagingRegion;
    use crate::procedure::repartition::group::remap_manifest::RemapManifest;
    use crate::procedure::repartition::group::repartition_start::RepartitionStart as GroupRepartitionStart;
    use crate::procedure::repartition::group::sync_region::SyncRegion;
    use crate::procedure::repartition::group::update_metadata::UpdateMetadata;
    use crate::procedure::repartition::plan;
    use crate::procedure::repartition::repartition_start::RepartitionStart as ParentRepartitionStart;
    use crate::procedure::repartition::test_util::{
        TestingEnv, new_persistent_context, range_expr,
    };

    struct GroupRollbackFixture {
        context: Context,
        original_region_routes: Vec<RegionRoute>,
        next_state: Option<Box<dyn State>>,
    }

    async fn new_group_rollback_fixture(
        original_region_routes: Vec<RegionRoute>,
        from_exprs: Vec<PartitionExpr>,
        to_exprs: Vec<PartitionExpr>,
        sync_region: bool,
    ) -> GroupRollbackFixture {
        let env = TestingEnv::new();
        let procedure_ctx = TestingEnv::procedure_context();
        let table_id = 1024;
        let mut next_region_number = 10;

        env.create_physical_table_metadata(table_id, original_region_routes.clone())
            .await;

        let (_, physical_route) = env
            .table_metadata_manager
            .table_route_manager()
            .get_physical_table_route(table_id)
            .await
            .unwrap();
        let allocation_plans =
            ParentRepartitionStart::build_plan(&physical_route, &from_exprs, &to_exprs).unwrap();
        assert_eq!(allocation_plans.len(), 1);

        let repartition_plan = plan::convert_allocation_plan_to_repartition_plan(
            table_id,
            &mut next_region_number,
            &allocation_plans[0],
        );
        let region_mapping = build_region_mapping(
            &repartition_plan.source_regions,
            &repartition_plan.target_regions,
            &repartition_plan.transition_map,
        );
        let persistent_context = PersistentContext::new(
            repartition_plan.group_id,
            table_id,
            "test_catalog".to_string(),
            "test_schema".to_string(),
            repartition_plan.source_regions,
            repartition_plan.target_regions,
            region_mapping,
            sync_region,
            repartition_plan.allocated_region_ids,
            repartition_plan.pending_deallocate_region_ids,
            Duration::from_secs(120),
        );
        let mut context = env.create_context(persistent_context);
        let (next_state, _) = GroupRepartitionStart
            .next(&mut context, &procedure_ctx)
            .await
            .unwrap();

        GroupRollbackFixture {
            context,
            original_region_routes,
            next_state: Some(next_state),
        }
    }

    async fn new_split_group_rollback_fixture(sync_region: bool) -> GroupRollbackFixture {
        new_group_rollback_fixture(
            vec![
                new_region_route(RegionId::new(1024, 1), Some(range_expr("x", 0, 100))),
                new_region_route(RegionId::new(1024, 2), Some(range_expr("x", 100, 200))),
                new_region_route(RegionId::new(1024, 10), None),
            ],
            vec![range_expr("x", 0, 100)],
            vec![range_expr("x", 0, 50), range_expr("x", 50, 100)],
            sync_region,
        )
        .await
    }

    async fn new_merge_group_rollback_fixture(sync_region: bool) -> GroupRollbackFixture {
        new_group_rollback_fixture(
            vec![
                new_region_route(RegionId::new(1024, 1), Some(range_expr("x", 0, 100))),
                new_region_route(RegionId::new(1024, 2), Some(range_expr("x", 100, 200))),
                new_region_route(RegionId::new(1024, 3), Some(range_expr("x", 200, 300))),
            ],
            vec![range_expr("x", 0, 100), range_expr("x", 100, 200)],
            vec![range_expr("x", 0, 200)],
            sync_region,
        )
        .await
    }

    async fn stage_metadata(context: &mut Context) {
        UpdateMetadata::ApplyStaging
            .apply_staging_regions(context)
            .await
            .unwrap();
    }

    fn new_region_route(region_id: RegionId, partition_expr: Option<PartitionExpr>) -> RegionRoute {
        RegionRoute {
            region: Region {
                id: region_id,
                partition_expr: partition_expr
                    .map(|expr| expr.as_json_str().unwrap())
                    .unwrap_or_default(),
                ..Default::default()
            },
            leader_peer: Some(Peer::empty(1)),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_get_table_route_value_not_found_error() {
        let env = TestingEnv::new();
        let persistent_context = new_persistent_context(1024, vec![], vec![]);
        let ctx = env.create_context(persistent_context);
        let err = ctx.get_table_route_value().await.unwrap_err();
        assert_matches!(err, Error::TableRouteNotFound { .. });
        assert!(!err.is_retryable());
    }

    #[tokio::test]
    async fn test_get_table_route_value_retry_error() {
        let kv = MockKvBackendBuilder::default()
            .range_fn(Arc::new(|_| {
                common_meta::error::UnexpectedSnafu {
                    err_msg: "mock err",
                }
                .fail()
            }))
            .build()
            .unwrap();
        let mut env = TestingEnv::new();
        env.table_metadata_manager = Arc::new(TableMetadataManager::new(Arc::new(kv)));
        let persistent_context = new_persistent_context(1024, vec![], vec![]);
        let ctx = env.create_context(persistent_context);
        let err = ctx.get_table_route_value().await.unwrap_err();
        assert!(err.is_retryable());
    }

    #[tokio::test]
    async fn test_get_datanode_table_value_retry_error() {
        let kv = MockKvBackendBuilder::default()
            .range_fn(Arc::new(|_| {
                common_meta::error::UnexpectedSnafu {
                    err_msg: "mock err",
                }
                .fail()
            }))
            .build()
            .unwrap();
        let mut env = TestingEnv::new();
        env.table_metadata_manager = Arc::new(TableMetadataManager::new(Arc::new(kv)));
        let persistent_context = new_persistent_context(1024, vec![], vec![]);
        let ctx = env.create_context(persistent_context);
        let err = ctx.get_datanode_table_value(1024, 1).await.unwrap_err();
        assert!(err.is_retryable());
    }

    #[tokio::test]
    async fn test_group_rollback_supported() {
        let env = TestingEnv::new();
        let persistent_context = new_persistent_context(1024, vec![], vec![]);
        let procedure = RepartitionGroupProcedure {
            state: Box::new(RepartitionStart),
            context: env.create_context(persistent_context),
        };

        assert!(procedure.rollback_supported());
    }

    #[tokio::test]
    async fn test_group_rollback_is_noop_before_apply_staging() {
        let env = TestingEnv::new();
        let persistent_context = new_persistent_context(1024, vec![], vec![]);
        let ctx = env.create_context(persistent_context.clone());
        let mut procedure = RepartitionGroupProcedure {
            state: Box::new(RepartitionStart),
            context: ctx,
        };
        let provider = Arc::new(MockContextProvider::new(Default::default()));
        let procedure_ctx = ProcedureContext {
            procedure_id: ProcedureId::random(),
            provider,
        };

        procedure.rollback(&procedure_ctx).await.unwrap();

        assert!(procedure.state.as_any().is::<RepartitionStart>());
        assert_eq!(procedure.context.persistent_ctx, persistent_context);
    }

    async fn assert_noop_rollback(
        fixture: GroupRollbackFixture,
        state: Box<dyn State>,
        assert_state: impl FnOnce(&dyn State),
    ) {
        let original_region_routes = fixture.original_region_routes.clone();
        let procedure_ctx = TestingEnv::procedure_context();
        let mut procedure = RepartitionGroupProcedure {
            state,
            context: fixture.context,
        };

        procedure.rollback(&procedure_ctx).await.unwrap();

        assert_state(&*procedure.state);
        let table_route_value = procedure
            .context
            .get_table_route_value()
            .await
            .unwrap()
            .into_inner();
        let region_routes = region_routes(
            procedure.context.persistent_ctx.table_id,
            &table_route_value,
        )
        .unwrap();
        assert_eq!(region_routes.clone(), original_region_routes);
    }

    async fn assert_metadata_rollback_restores_table_route(
        mut fixture: GroupRollbackFixture,
        state: Box<dyn State>,
    ) {
        let original_region_routes = fixture.original_region_routes.clone();
        let procedure_ctx = TestingEnv::procedure_context();
        stage_metadata(&mut fixture.context).await;
        let mut procedure = RepartitionGroupProcedure {
            state,
            context: fixture.context,
        };

        procedure.rollback(&procedure_ctx).await.unwrap();

        let table_route_value = procedure
            .context
            .get_table_route_value()
            .await
            .unwrap()
            .into_inner();
        let region_routes = region_routes(
            procedure.context.persistent_ctx.table_id,
            &table_route_value,
        )
        .unwrap();
        assert_eq!(region_routes.clone(), original_region_routes);
    }

    #[tokio::test]
    async fn test_group_rollback_is_noop_in_sync_region() {
        let mut fixture = new_split_group_rollback_fixture(true).await;
        assert!(
            fixture
                .next_state
                .as_ref()
                .unwrap()
                .as_any()
                .is::<SyncRegion>()
        );
        let state = fixture.next_state.take().unwrap();

        assert_noop_rollback(fixture, state, |state| {
            assert!(state.as_any().is::<SyncRegion>());
        })
        .await;
    }

    #[tokio::test]
    async fn test_group_rollback_is_noop_in_exit_staging() {
        let fixture = new_split_group_rollback_fixture(false).await;

        assert_noop_rollback(fixture, Box::new(UpdateMetadata::ExitStaging), |state| {
            assert!(state.as_any().is::<UpdateMetadata>());
            assert!(matches!(
                state.as_any().downcast_ref::<UpdateMetadata>(),
                Some(UpdateMetadata::ExitStaging)
            ));
        })
        .await;
    }

    #[tokio::test]
    async fn test_group_rollback_restores_split_routes_from_apply_staging() {
        let fixture = new_split_group_rollback_fixture(false).await;
        assert_metadata_rollback_restores_table_route(
            fixture,
            Box::new(UpdateMetadata::ApplyStaging),
        )
        .await;
    }

    #[tokio::test]
    async fn test_group_rollback_restores_split_routes_from_enter_staging_region() {
        let fixture = new_split_group_rollback_fixture(false).await;
        assert_metadata_rollback_restores_table_route(fixture, Box::new(EnterStagingRegion)).await;
    }

    #[tokio::test]
    async fn test_group_rollback_restores_split_routes_from_remap_manifest() {
        let fixture = new_split_group_rollback_fixture(false).await;
        assert_metadata_rollback_restores_table_route(fixture, Box::new(RemapManifest)).await;
    }

    #[tokio::test]
    async fn test_group_rollback_restores_split_routes_from_apply_staging_manifest() {
        let fixture = new_split_group_rollback_fixture(false).await;
        assert_metadata_rollback_restores_table_route(fixture, Box::new(ApplyStagingManifest))
            .await;
    }

    #[tokio::test]
    async fn test_group_rollback_restores_merge_routes_and_is_idempotent() {
        let mut fixture = new_merge_group_rollback_fixture(false).await;
        let original_region_routes = fixture.original_region_routes.clone();
        let procedure_ctx = TestingEnv::procedure_context();
        stage_metadata(&mut fixture.context).await;
        let mut procedure = RepartitionGroupProcedure {
            state: Box::new(UpdateMetadata::ApplyStaging),
            context: fixture.context,
        };

        procedure.rollback(&procedure_ctx).await.unwrap();
        let table_route_value = procedure
            .context
            .get_table_route_value()
            .await
            .unwrap()
            .into_inner();
        let once = region_routes(
            procedure.context.persistent_ctx.table_id,
            &table_route_value,
        )
        .unwrap()
        .clone();
        procedure.rollback(&procedure_ctx).await.unwrap();
        let table_route_value = procedure
            .context
            .get_table_route_value()
            .await
            .unwrap()
            .into_inner();
        let twice = region_routes(
            procedure.context.persistent_ctx.table_id,
            &table_route_value,
        )
        .unwrap()
        .clone();

        assert_eq!(once, original_region_routes);
        assert_eq!(once, twice);
    }
}
1053}