common_meta/key/
table_route.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16use std::fmt::Display;
17use std::sync::Arc;
18
19use serde::{Deserialize, Deserializer, Serialize};
20use snafu::{OptionExt, ResultExt, ensure};
21use store_api::storage::{RegionId, RegionNumber};
22use table::metadata::TableId;
23
24use crate::error::{
25    InvalidMetadataSnafu, MetadataCorruptionSnafu, RegionNotFoundSnafu, Result, SerdeJsonSnafu,
26    TableRouteNotFoundSnafu, UnexpectedLogicalRouteTableSnafu,
27};
28use crate::key::node_address::{NodeAddressKey, NodeAddressValue};
29use crate::key::txn_helper::TxnOpGetResponseSet;
30use crate::key::{
31    DeserializedValueWithBytes, MetadataKey, MetadataValue, RegionDistribution,
32    TABLE_ROUTE_KEY_PATTERN, TABLE_ROUTE_PREFIX,
33};
34use crate::kv_backend::KvBackendRef;
35use crate::kv_backend::txn::Txn;
36use crate::rpc::router::{RegionRoute, region_distribution};
37use crate::rpc::store::BatchGetRequest;
38
39/// The key stores table routes
40///
41/// The layout: `__table_route/{table_id}`.
42#[derive(Debug, PartialEq)]
43pub struct TableRouteKey {
44    pub table_id: TableId,
45}
46
47impl TableRouteKey {
48    pub fn new(table_id: TableId) -> Self {
49        Self { table_id }
50    }
51
52    /// Returns the range prefix of the table route key.
53    pub fn range_prefix() -> Vec<u8> {
54        format!("{}/", TABLE_ROUTE_PREFIX).into_bytes()
55    }
56}
57
58#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
59#[serde(tag = "type", rename_all = "snake_case")]
60pub enum TableRouteValue {
61    Physical(PhysicalTableRouteValue),
62    Logical(LogicalTableRouteValue),
63}
64
65#[derive(Debug, PartialEq, Serialize, Clone, Default)]
66pub struct PhysicalTableRouteValue {
67    // The region routes of the table.
68    pub region_routes: Vec<RegionRoute>,
69    // Tracks the highest region number ever allocated for the table.
70    // This value only increases: adding a region updates it if needed,
71    // and dropping regions does not decrease it.
72    pub max_region_number: RegionNumber,
73    // The version of the table route.
74    version: u64,
75}
76
77impl<'de> Deserialize<'de> for PhysicalTableRouteValue {
78    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
79    where
80        D: Deserializer<'de>,
81    {
82        #[derive(Deserialize)]
83        struct Helper {
84            region_routes: Vec<RegionRoute>,
85            #[serde(default)]
86            max_region_number: Option<RegionNumber>,
87            version: u64,
88        }
89
90        let mut helper = Helper::deserialize(deserializer)?;
91        // If the max region number is not provided, we will calculate it from the region routes.
92        if helper.max_region_number.is_none() {
93            let max_region = helper
94                .region_routes
95                .iter()
96                .map(|r| r.region.id.region_number())
97                .max()
98                .unwrap_or_default();
99            helper.max_region_number = Some(max_region);
100        }
101
102        Ok(PhysicalTableRouteValue {
103            region_routes: helper.region_routes,
104            max_region_number: helper.max_region_number.unwrap_or_default(),
105            version: helper.version,
106        })
107    }
108}
109
110#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
111pub struct LogicalTableRouteValue {
112    physical_table_id: TableId,
113}
114
115impl TableRouteValue {
116    /// Returns a [TableRouteValue::Physical] if `table_id` equals `physical_table_id`.
117    /// Otherwise returns a [TableRouteValue::Logical].
118    pub(crate) fn new(
119        table_id: TableId,
120        physical_table_id: TableId,
121        region_routes: Vec<RegionRoute>,
122    ) -> Self {
123        if table_id == physical_table_id {
124            TableRouteValue::physical(region_routes)
125        } else {
126            TableRouteValue::logical(physical_table_id)
127        }
128    }
129
130    pub fn physical(region_routes: Vec<RegionRoute>) -> Self {
131        Self::Physical(PhysicalTableRouteValue::new(region_routes))
132    }
133
134    pub fn logical(physical_table_id: TableId) -> Self {
135        Self::Logical(LogicalTableRouteValue::new(physical_table_id))
136    }
137
138    /// Returns a new version [TableRouteValue] with `region_routes`.
139    pub fn update(&self, region_routes: Vec<RegionRoute>) -> Result<Self> {
140        ensure!(
141            self.is_physical(),
142            UnexpectedLogicalRouteTableSnafu {
143                err_msg: format!("{self:?} is a non-physical TableRouteValue."),
144            }
145        );
146        let physical_table_route = self.as_physical_table_route_ref();
147        let original_max_region_number = physical_table_route.max_region_number;
148        let new_max_region_number = region_routes
149            .iter()
150            .map(|r| r.region.id.region_number())
151            .max()
152            .unwrap_or_default();
153        let version = physical_table_route.version;
154        Ok(Self::Physical(PhysicalTableRouteValue {
155            region_routes,
156            // If region routes are added, we will update the max region number.
157            // If region routes are removed, we will keep the original max region number.
158            max_region_number: original_max_region_number.max(new_max_region_number),
159            version: version + 1,
160        }))
161    }
162
163    /// Returns the version.
164    ///
165    /// For test purpose.
166    #[cfg(any(test, feature = "testing"))]
167    pub fn version(&self) -> Result<u64> {
168        ensure!(
169            self.is_physical(),
170            UnexpectedLogicalRouteTableSnafu {
171                err_msg: format!("{self:?} is a non-physical TableRouteValue."),
172            }
173        );
174        Ok(self.as_physical_table_route_ref().version)
175    }
176
177    /// Returns the corresponding [RegionRoute], returns `None` if it's the specific region is not found.
178    ///
179    /// Note: It throws an error if it's a logical table
180    pub fn region_route(&self, region_id: RegionId) -> Result<Option<RegionRoute>> {
181        ensure!(
182            self.is_physical(),
183            UnexpectedLogicalRouteTableSnafu {
184                err_msg: format!("{self:?} is a non-physical TableRouteValue."),
185            }
186        );
187        Ok(self
188            .as_physical_table_route_ref()
189            .region_routes
190            .iter()
191            .find(|route| route.region.id == region_id)
192            .cloned())
193    }
194
195    /// Returns true if it's [TableRouteValue::Physical].
196    pub fn is_physical(&self) -> bool {
197        matches!(self, TableRouteValue::Physical(_))
198    }
199
200    /// Gets the [RegionRoute]s of this [TableRouteValue::Physical].
201    pub fn region_routes(&self) -> Result<&Vec<RegionRoute>> {
202        ensure!(
203            self.is_physical(),
204            UnexpectedLogicalRouteTableSnafu {
205                err_msg: format!("{self:?} is a non-physical TableRouteValue."),
206            }
207        );
208        Ok(&self.as_physical_table_route_ref().region_routes)
209    }
210
211    /// Returns the max region number of this [TableRouteValue::Physical].
212    ///
213    /// # Panic
214    /// If it is not the [`PhysicalTableRouteValue`].
215    pub fn max_region_number(&self) -> Result<RegionNumber> {
216        ensure!(
217            self.is_physical(),
218            UnexpectedLogicalRouteTableSnafu {
219                err_msg: format!("{self:?} is a non-physical TableRouteValue."),
220            }
221        );
222        Ok(self.as_physical_table_route_ref().max_region_number)
223    }
224
225    /// Returns the reference of [`PhysicalTableRouteValue`].
226    ///
227    /// # Panic
228    /// If it is not the [`PhysicalTableRouteValue`].
229    fn as_physical_table_route_ref(&self) -> &PhysicalTableRouteValue {
230        match self {
231            TableRouteValue::Physical(x) => x,
232            _ => unreachable!("Mistakenly been treated as a Physical TableRoute: {self:?}"),
233        }
234    }
235
236    /// Converts to [`PhysicalTableRouteValue`].
237    ///
238    /// # Panic
239    /// If it is not the [`PhysicalTableRouteValue`].
240    pub fn into_physical_table_route(self) -> PhysicalTableRouteValue {
241        match self {
242            TableRouteValue::Physical(x) => x,
243            _ => unreachable!("Mistakenly been treated as a Physical TableRoute: {self:?}"),
244        }
245    }
246
247    /// Converts to [`LogicalTableRouteValue`].
248    ///
249    /// # Panic
250    /// If it is not the [`LogicalTableRouteValue`].
251    pub fn into_logical_table_route(self) -> LogicalTableRouteValue {
252        match self {
253            TableRouteValue::Logical(x) => x,
254            _ => unreachable!("Mistakenly been treated as a Logical TableRoute: {self:?}"),
255        }
256    }
257
258    pub fn region_numbers(&self) -> Vec<RegionNumber> {
259        match self {
260            TableRouteValue::Physical(x) => x
261                .region_routes
262                .iter()
263                .map(|region_route| region_route.region.id.region_number())
264                .collect(),
265            TableRouteValue::Logical(_) => {
266                vec![]
267            }
268        }
269    }
270}
271
272impl MetadataValue for TableRouteValue {
273    fn try_from_raw_value(raw_value: &[u8]) -> Result<Self> {
274        let r = serde_json::from_slice::<TableRouteValue>(raw_value);
275        match r {
276            // Compatible with old TableRouteValue.
277            Err(e) if e.is_data() => Ok(Self::Physical(
278                serde_json::from_slice::<PhysicalTableRouteValue>(raw_value)
279                    .context(SerdeJsonSnafu)?,
280            )),
281            Ok(x) => Ok(x),
282            Err(e) => Err(e).context(SerdeJsonSnafu),
283        }
284    }
285
286    fn try_as_raw_value(&self) -> Result<Vec<u8>> {
287        serde_json::to_vec(self).context(SerdeJsonSnafu)
288    }
289}
290
291impl PhysicalTableRouteValue {
292    pub fn new(region_routes: Vec<RegionRoute>) -> Self {
293        let max_region_number = region_routes
294            .iter()
295            .map(|r| r.region.id.region_number())
296            .max()
297            .unwrap_or_default();
298        Self {
299            region_routes,
300            max_region_number,
301            version: 0,
302        }
303    }
304}
305
306impl LogicalTableRouteValue {
307    pub fn new(physical_table_id: TableId) -> Self {
308        Self { physical_table_id }
309    }
310
311    pub fn physical_table_id(&self) -> TableId {
312        self.physical_table_id
313    }
314}
315
316impl MetadataKey<'_, TableRouteKey> for TableRouteKey {
317    fn to_bytes(&self) -> Vec<u8> {
318        self.to_string().into_bytes()
319    }
320
321    fn from_bytes(bytes: &[u8]) -> Result<TableRouteKey> {
322        let key = std::str::from_utf8(bytes).map_err(|e| {
323            InvalidMetadataSnafu {
324                err_msg: format!(
325                    "TableRouteKey '{}' is not a valid UTF8 string: {e}",
326                    String::from_utf8_lossy(bytes)
327                ),
328            }
329            .build()
330        })?;
331        let captures = TABLE_ROUTE_KEY_PATTERN
332            .captures(key)
333            .context(InvalidMetadataSnafu {
334                err_msg: format!("Invalid TableRouteKey '{key}'"),
335            })?;
336        // Safety: pass the regex check above
337        let table_id = captures[1].parse::<TableId>().unwrap();
338        Ok(TableRouteKey { table_id })
339    }
340}
341
342impl Display for TableRouteKey {
343    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        write!(f, "{}/{}", TABLE_ROUTE_PREFIX, self.table_id)
345    }
346}
347
348pub type TableRouteManagerRef = Arc<TableRouteManager>;
349
350pub struct TableRouteManager {
351    storage: TableRouteStorage,
352}
353
354impl TableRouteManager {
355    pub fn new(kv_backend: KvBackendRef) -> Self {
356        Self {
357            storage: TableRouteStorage::new(kv_backend),
358        }
359    }
360
361    /// Returns the [TableId] recursively.
362    ///
363    /// Returns a [TableRouteNotFound](crate::error::Error::TableRouteNotFound) Error if:
364    /// - the table(`logical_or_physical_table_id`) does not exist.
365    pub async fn get_physical_table_id(
366        &self,
367        logical_or_physical_table_id: TableId,
368    ) -> Result<TableId> {
369        let table_route = self
370            .storage
371            .get_inner(logical_or_physical_table_id)
372            .await?
373            .context(TableRouteNotFoundSnafu {
374                table_id: logical_or_physical_table_id,
375            })?;
376
377        match table_route {
378            TableRouteValue::Physical(_) => Ok(logical_or_physical_table_id),
379            TableRouteValue::Logical(x) => Ok(x.physical_table_id()),
380        }
381    }
382
383    /// Returns the [TableRouteValue::Physical] recursively.
384    ///
385    /// Returns a [TableRouteNotFound](error::Error::TableRouteNotFound) Error if:
386    /// - the physical table(`logical_or_physical_table_id`) does not exist
387    /// - the corresponding physical table of the logical table(`logical_or_physical_table_id`) does not exist.
388    pub async fn get_physical_table_route(
389        &self,
390        logical_or_physical_table_id: TableId,
391    ) -> Result<(TableId, PhysicalTableRouteValue)> {
392        let table_route = self
393            .storage
394            .get(logical_or_physical_table_id)
395            .await?
396            .context(TableRouteNotFoundSnafu {
397                table_id: logical_or_physical_table_id,
398            })?;
399
400        match table_route {
401            TableRouteValue::Physical(x) => Ok((logical_or_physical_table_id, x)),
402            TableRouteValue::Logical(x) => {
403                let physical_table_id = x.physical_table_id();
404                let physical_table_route = self.storage.get(physical_table_id).await?.context(
405                    TableRouteNotFoundSnafu {
406                        table_id: physical_table_id,
407                    },
408                )?;
409                let physical_table_route = physical_table_route.into_physical_table_route();
410                Ok((physical_table_id, physical_table_route))
411            }
412        }
413    }
414
415    /// Returns the [TableRouteValue::Physical] recursively.
416    ///
417    /// Returns a [TableRouteNotFound](crate::error::Error::TableRouteNotFound) Error if:
418    /// - one of the logical tables corresponding to the physical table does not exist.
419    ///
420    /// **Notes**: it may return a subset of `logical_or_physical_table_ids`.
421    pub async fn batch_get_physical_table_routes(
422        &self,
423        logical_or_physical_table_ids: &[TableId],
424    ) -> Result<HashMap<TableId, PhysicalTableRouteValue>> {
425        let table_routes = self
426            .storage
427            .batch_get(logical_or_physical_table_ids)
428            .await?;
429        // Returns a subset of `logical_or_physical_table_ids`.
430        let table_routes = table_routes
431            .into_iter()
432            .zip(logical_or_physical_table_ids)
433            .filter_map(|(route, id)| route.map(|route| (*id, route)))
434            .collect::<HashMap<_, _>>();
435
436        let mut physical_table_routes = HashMap::with_capacity(table_routes.len());
437        let mut logical_table_ids = HashMap::with_capacity(table_routes.len());
438
439        for (table_id, table_route) in table_routes {
440            match table_route {
441                TableRouteValue::Physical(x) => {
442                    physical_table_routes.insert(table_id, x);
443                }
444                TableRouteValue::Logical(x) => {
445                    logical_table_ids.insert(table_id, x.physical_table_id());
446                }
447            }
448        }
449
450        if logical_table_ids.is_empty() {
451            return Ok(physical_table_routes);
452        }
453
454        // Finds the logical tables corresponding to the physical tables.
455        let physical_table_ids = logical_table_ids
456            .values()
457            .cloned()
458            .collect::<HashSet<_>>()
459            .into_iter()
460            .collect::<Vec<_>>();
461        let table_routes = self
462            .table_route_storage()
463            .batch_get(&physical_table_ids)
464            .await?;
465        let table_routes = table_routes
466            .into_iter()
467            .zip(physical_table_ids)
468            .filter_map(|(route, id)| route.map(|route| (id, route)))
469            .collect::<HashMap<_, _>>();
470
471        for (logical_table_id, physical_table_id) in logical_table_ids {
472            let table_route =
473                table_routes
474                    .get(&physical_table_id)
475                    .context(TableRouteNotFoundSnafu {
476                        table_id: physical_table_id,
477                    })?;
478            match table_route {
479                TableRouteValue::Physical(x) => {
480                    physical_table_routes.insert(logical_table_id, x.clone());
481                }
482                TableRouteValue::Logical(x) => {
483                    // Never get here, because we use a physical table id cannot obtain a logical table.
484                    MetadataCorruptionSnafu {
485                        err_msg: format!(
486                            "logical table {} {:?} cannot be resolved to a physical table.",
487                            logical_table_id, x
488                        ),
489                    }
490                    .fail()?;
491                }
492            }
493        }
494
495        Ok(physical_table_routes)
496    }
497
498    /// Returns [`RegionDistribution`] of the table(`table_id`).
499    pub async fn get_region_distribution(
500        &self,
501        table_id: TableId,
502    ) -> Result<Option<RegionDistribution>> {
503        self.storage
504            .get(table_id)
505            .await?
506            .map(|table_route| Ok(region_distribution(table_route.region_routes()?)))
507            .transpose()
508    }
509
510    /// Sets the staging state for a specific region.
511    ///
512    /// Returns a [TableRouteNotFound](crate::error::Error::TableRouteNotFound) Error if:
513    /// - the table does not exist
514    /// - the region is not found in the table
515    pub async fn set_region_staging_state(
516        &self,
517        region_id: store_api::storage::RegionId,
518        staging: bool,
519    ) -> Result<()> {
520        let table_id = region_id.table_id();
521
522        // Get current table route with raw bytes for CAS operation
523        let current_table_route = self
524            .storage
525            .get_with_raw_bytes(table_id)
526            .await?
527            .context(TableRouteNotFoundSnafu { table_id })?;
528
529        // Clone the current route value and update the specific region
530        let new_table_route = current_table_route.inner.clone();
531
532        // Only physical tables have region routes
533        ensure!(
534            new_table_route.is_physical(),
535            UnexpectedLogicalRouteTableSnafu {
536                err_msg: format!("Cannot set staging state for logical table {table_id}"),
537            }
538        );
539
540        let region_routes = new_table_route.region_routes()?.clone();
541        let mut updated_routes = region_routes.clone();
542
543        // Find and update the specific region
544        // TODO(ruihang): maybe update them in one transaction
545        let mut region_found = false;
546        for route in &mut updated_routes {
547            if route.region.id == region_id {
548                if staging {
549                    route.set_leader_staging();
550                } else {
551                    route.clear_leader_staging();
552                }
553                region_found = true;
554                break;
555            }
556        }
557
558        ensure!(region_found, RegionNotFoundSnafu { region_id });
559
560        // Create new table route with updated region routes
561        let updated_table_route = new_table_route.update(updated_routes)?;
562
563        // Execute atomic update
564        let (txn, _) =
565            self.storage
566                .build_update_txn(table_id, &current_table_route, &updated_table_route)?;
567
568        let result = self.storage.kv_backend.txn(txn).await?;
569
570        ensure!(
571            result.succeeded,
572            MetadataCorruptionSnafu {
573                err_msg: format!(
574                    "Failed to update staging state for region {}: CAS operation failed",
575                    region_id
576                ),
577            }
578        );
579
580        Ok(())
581    }
582
583    /// Checks if a specific region is in staging state.
584    ///
585    /// Returns false if the table/region doesn't exist.
586    pub async fn is_region_staging(&self, region_id: store_api::storage::RegionId) -> Result<bool> {
587        let table_id = region_id.table_id();
588
589        let table_route = self.storage.get(table_id).await?;
590
591        match table_route {
592            Some(route) if route.is_physical() => {
593                let region_routes = route.region_routes()?;
594                for route in region_routes {
595                    if route.region.id == region_id {
596                        return Ok(route.is_leader_staging());
597                    }
598                }
599                Ok(false)
600            }
601            _ => Ok(false),
602        }
603    }
604
605    /// Returns low-level APIs.
606    pub fn table_route_storage(&self) -> &TableRouteStorage {
607        &self.storage
608    }
609}
610
611/// Low-level operations of [TableRouteValue].
612pub struct TableRouteStorage {
613    kv_backend: KvBackendRef,
614}
615
616pub type TableRouteValueDecodeResult = Result<Option<DeserializedValueWithBytes<TableRouteValue>>>;
617
618impl TableRouteStorage {
619    pub fn new(kv_backend: KvBackendRef) -> Self {
620        Self { kv_backend }
621    }
622
623    /// Builds a create table route transaction,
624    /// it expected the `__table_route/{table_id}` wasn't occupied.
625    pub fn build_create_txn(
626        &self,
627        table_id: TableId,
628        table_route_value: &TableRouteValue,
629    ) -> Result<(
630        Txn,
631        impl FnOnce(&mut TxnOpGetResponseSet) -> TableRouteValueDecodeResult + use<>,
632    )> {
633        let key = TableRouteKey::new(table_id);
634        let raw_key = key.to_bytes();
635
636        let txn = Txn::put_if_not_exists(raw_key.clone(), table_route_value.try_as_raw_value()?);
637
638        Ok((
639            txn,
640            TxnOpGetResponseSet::decode_with(TxnOpGetResponseSet::filter(raw_key)),
641        ))
642    }
643
644    // TODO(LFC): restore its original visibility after some test utility codes are refined
645    /// Builds a update table route transaction,
646    /// it expected the remote value equals the `current_table_route_value`.
647    /// It retrieves the latest value if the comparing failed.
648    pub fn build_update_txn(
649        &self,
650        table_id: TableId,
651        current_table_route_value: &DeserializedValueWithBytes<TableRouteValue>,
652        new_table_route_value: &TableRouteValue,
653    ) -> Result<(
654        Txn,
655        impl FnOnce(&mut TxnOpGetResponseSet) -> TableRouteValueDecodeResult + use<>,
656    )> {
657        let key = TableRouteKey::new(table_id);
658        let raw_key = key.to_bytes();
659        let raw_value = current_table_route_value.get_raw_bytes();
660        let new_raw_value: Vec<u8> = new_table_route_value.try_as_raw_value()?;
661
662        let txn = Txn::compare_and_put(raw_key.clone(), raw_value, new_raw_value);
663
664        Ok((
665            txn,
666            TxnOpGetResponseSet::decode_with(TxnOpGetResponseSet::filter(raw_key)),
667        ))
668    }
669
670    /// Returns the [`TableRouteValue`].
671    pub async fn get(&self, table_id: TableId) -> Result<Option<TableRouteValue>> {
672        let mut table_route = self.get_inner(table_id).await?;
673        if let Some(table_route) = &mut table_route {
674            self.remap_route_address(table_route).await?;
675        };
676
677        Ok(table_route)
678    }
679
680    async fn get_inner(&self, table_id: TableId) -> Result<Option<TableRouteValue>> {
681        let key = TableRouteKey::new(table_id);
682        self.kv_backend
683            .get(&key.to_bytes())
684            .await?
685            .map(|kv| TableRouteValue::try_from_raw_value(&kv.value))
686            .transpose()
687    }
688
689    /// Returns the [`TableRouteValue`] wrapped with [`DeserializedValueWithBytes`].
690    pub async fn get_with_raw_bytes(
691        &self,
692        table_id: TableId,
693    ) -> Result<Option<DeserializedValueWithBytes<TableRouteValue>>> {
694        let mut table_route = self.get_with_raw_bytes_inner(table_id).await?;
695        if let Some(table_route) = &mut table_route {
696            self.remap_route_address(table_route).await?;
697        };
698
699        Ok(table_route)
700    }
701
702    async fn get_with_raw_bytes_inner(
703        &self,
704        table_id: TableId,
705    ) -> Result<Option<DeserializedValueWithBytes<TableRouteValue>>> {
706        let key = TableRouteKey::new(table_id);
707        self.kv_backend
708            .get(&key.to_bytes())
709            .await?
710            .map(|kv| DeserializedValueWithBytes::from_inner_slice(&kv.value))
711            .transpose()
712    }
713
714    /// Returns batch of [`TableRouteValue`] that respects the order of `table_ids`.
715    pub async fn batch_get(&self, table_ids: &[TableId]) -> Result<Vec<Option<TableRouteValue>>> {
716        let raw_table_routes = self.batch_get_inner(table_ids).await?;
717
718        Ok(raw_table_routes
719            .into_iter()
720            .map(|v| v.map(|x| x.inner))
721            .collect())
722    }
723
724    /// Returns batch of [`TableRouteValue`] wrapped with [`DeserializedValueWithBytes`].
725    ///
726    /// The return value is a vector of [`Option<DeserializedValueWithBytes<TableRouteValue>>`].
727    /// Note: This method remaps the addresses of the table routes, but does not update their raw byte representations.
728    pub async fn batch_get_with_raw_bytes(
729        &self,
730        table_ids: &[TableId],
731    ) -> Result<Vec<Option<DeserializedValueWithBytes<TableRouteValue>>>> {
732        let mut raw_table_routes = self.batch_get_inner(table_ids).await?;
733        self.remap_routes_addresses(&mut raw_table_routes).await?;
734
735        Ok(raw_table_routes)
736    }
737
738    async fn batch_get_inner(
739        &self,
740        table_ids: &[TableId],
741    ) -> Result<Vec<Option<DeserializedValueWithBytes<TableRouteValue>>>> {
742        let keys = table_ids
743            .iter()
744            .map(|id| TableRouteKey::new(*id).to_bytes())
745            .collect::<Vec<_>>();
746        let resp = self
747            .kv_backend
748            .batch_get(BatchGetRequest { keys: keys.clone() })
749            .await?;
750
751        let kvs = resp
752            .kvs
753            .into_iter()
754            .map(|kv| (kv.key, kv.value))
755            .collect::<HashMap<_, _>>();
756        keys.into_iter()
757            .map(|key| {
758                if let Some(value) = kvs.get(&key) {
759                    Ok(Some(DeserializedValueWithBytes::from_inner_slice(value)?))
760                } else {
761                    Ok(None)
762                }
763            })
764            .collect()
765    }
766
767    async fn remap_routes_addresses(
768        &self,
769        table_routes: &mut [Option<DeserializedValueWithBytes<TableRouteValue>>],
770    ) -> Result<()> {
771        let keys = table_routes
772            .iter()
773            .flat_map(|table_route| {
774                table_route
775                    .as_ref()
776                    .map(|x| extract_address_keys(&x.inner))
777                    .unwrap_or_default()
778            })
779            .collect::<HashSet<_>>()
780            .into_iter()
781            .collect();
782        let node_addrs = self.get_node_addresses(keys).await?;
783        for table_route in table_routes.iter_mut().flatten() {
784            set_addresses(&node_addrs, table_route)?;
785        }
786
787        Ok(())
788    }
789
790    async fn remap_route_address(&self, table_route: &mut TableRouteValue) -> Result<()> {
791        let keys = extract_address_keys(table_route).into_iter().collect();
792        let node_addrs = self.get_node_addresses(keys).await?;
793        set_addresses(&node_addrs, table_route)?;
794
795        Ok(())
796    }
797
798    async fn get_node_addresses(
799        &self,
800        keys: Vec<Vec<u8>>,
801    ) -> Result<HashMap<u64, NodeAddressValue>> {
802        if keys.is_empty() {
803            return Ok(HashMap::default());
804        }
805
806        self.kv_backend
807            .batch_get(BatchGetRequest { keys })
808            .await?
809            .kvs
810            .into_iter()
811            .map(|kv| {
812                let node_id = NodeAddressKey::from_bytes(&kv.key)?.node_id;
813                let node_addr = NodeAddressValue::try_from_raw_value(&kv.value)?;
814                Ok((node_id, node_addr))
815            })
816            .collect()
817    }
818}
819
820fn set_addresses(
821    node_addrs: &HashMap<u64, NodeAddressValue>,
822    table_route: &mut TableRouteValue,
823) -> Result<()> {
824    let TableRouteValue::Physical(physical_table_route) = table_route else {
825        return Ok(());
826    };
827
828    for region_route in &mut physical_table_route.region_routes {
829        if let Some(leader) = &mut region_route.leader_peer
830            && let Some(node_addr) = node_addrs.get(&leader.id)
831        {
832            leader.addr = node_addr.peer.addr.clone();
833        }
834        for follower in &mut region_route.follower_peers {
835            if let Some(node_addr) = node_addrs.get(&follower.id) {
836                follower.addr = node_addr.peer.addr.clone();
837            }
838        }
839    }
840
841    Ok(())
842}
843
844fn extract_address_keys(table_route: &TableRouteValue) -> HashSet<Vec<u8>> {
845    let TableRouteValue::Physical(physical_table_route) = table_route else {
846        return HashSet::default();
847    };
848
849    physical_table_route
850        .region_routes
851        .iter()
852        .flat_map(|region_route| {
853            region_route
854                .follower_peers
855                .iter()
856                .map(|peer| NodeAddressKey::with_datanode(peer.id).to_bytes())
857                .chain(
858                    region_route
859                        .leader_peer
860                        .as_ref()
861                        .map(|leader| NodeAddressKey::with_datanode(leader.id).to_bytes()),
862                )
863        })
864        .collect()
865}
866
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::kv_backend::memory::MemoryKvBackend;
    use crate::kv_backend::{KvBackend, TxnService};
    use crate::peer::Peer;
    use crate::rpc::router::Region;
    use crate::rpc::store::PutRequest;

    /// Builds an otherwise-default region route for region `region_number` of table 0.
    fn route_for_region(region_number: RegionNumber) -> RegionRoute {
        RegionRoute {
            region: Region {
                id: RegionId::new(0, region_number),
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// `max_region_number` stays put on updates with lower region numbers and
    /// grows on updates with higher ones.
    #[test]
    fn test_update_table_route_max_region_number() {
        let initial = PhysicalTableRouteValue::new(vec![route_for_region(1), route_for_region(2)]);
        assert_eq!(initial.max_region_number, 2);

        // Shouldn't change the max region number.
        let after_lower = TableRouteValue::Physical(initial)
            .update(vec![route_for_region(1)])
            .unwrap();
        assert_eq!(
            after_lower.as_physical_table_route_ref().max_region_number,
            2
        );

        // Should increase the max region number.
        let after_higher = after_lower
            .update(vec![route_for_region(3)])
            .unwrap()
            .into_physical_table_route();
        assert_eq!(after_higher.max_region_number, 3);
    }

    /// A raw value without a variant tag or `max_region_number` still decodes
    /// as a physical route.
    #[test]
    fn test_table_route_compatibility() {
        let old_raw_v = r#"{"region_routes":[{"region":{"id":1,"name":"r1","partition":null,"attrs":{}},"leader_peer":{"id":2,"addr":"a2"},"follower_peers":[]},{"region":{"id":1,"name":"r1","partition":null,"attrs":{}},"leader_peer":{"id":2,"addr":"a2"},"follower_peers":[]}],"version":0}"#;
        let decoded = TableRouteValue::try_from_raw_value(old_raw_v.as_bytes()).unwrap();

        // Both entries of the legacy payload are identical, so build them from one recipe.
        let legacy_route = || RegionRoute {
            region: Region {
                id: RegionId::new(0, 1),
                name: "r1".to_string(),
                partition: None,
                attrs: Default::default(),
                partition_expr: Default::default(),
            },
            leader_peer: Some(Peer {
                id: 2,
                addr: "a2".to_string(),
            }),
            follower_peers: vec![],
            leader_state: None,
            leader_down_since: None,
        };
        let expected = TableRouteValue::Physical(PhysicalTableRouteValue {
            region_routes: vec![legacy_route(), legacy_route()],
            max_region_number: 1,
            version: 0,
        });

        assert_eq!(decoded, expected);
    }

    /// A key serializes to `__table_route/{table_id}`.
    #[test]
    fn test_key_serialization() {
        let encoded = TableRouteKey::new(42).to_bytes();
        assert_eq!(encoded, b"__table_route/42");
    }

    /// The serialized form parses back into the same key.
    #[test]
    fn test_key_deserialization() {
        let decoded = TableRouteKey::from_bytes(b"__table_route/42").unwrap();
        assert_eq!(decoded, TableRouteKey::new(42));
    }

    /// Looking up a table id that was never stored yields `None`.
    #[tokio::test]
    async fn test_table_route_storage_get_with_raw_bytes_empty() {
        let backend = Arc::new(MemoryKvBackend::default());
        let storage = TableRouteStorage::new(backend);

        let found = storage.get_with_raw_bytes(1024).await.unwrap();
        assert!(found.is_none());
    }

    /// A route written through a create transaction is read back unchanged.
    #[tokio::test]
    async fn test_table_route_storage_get_with_raw_bytes() {
        let backend = Arc::new(MemoryKvBackend::default());
        let storage = TableRouteStorage::new(backend.clone());

        // Nothing has been stored yet.
        let missing = storage.get_with_raw_bytes(1024).await.unwrap();
        assert!(missing.is_none());

        let manager = TableRouteManager::new(backend.clone());
        let logical_route = TableRouteValue::Logical(LogicalTableRouteValue {
            physical_table_id: 1023,
        });
        let (create_txn, _) = manager
            .table_route_storage()
            .build_create_txn(1024, &logical_route)
            .unwrap();
        let response = backend.txn(create_txn).await.unwrap();
        assert!(response.succeeded);

        let stored = storage.get_with_raw_bytes(1024).await.unwrap();
        assert!(stored.is_some());
        assert_eq!(stored.unwrap().inner, logical_route);
    }

    /// `batch_get` returns one slot per requested id, in request order, with
    /// `None` for ids that have no stored route.
    #[tokio::test]
    async fn test_table_route_batch_get() {
        let backend = Arc::new(MemoryKvBackend::default());
        let storage = TableRouteStorage::new(backend.clone());

        let before = storage.batch_get(&[1023, 1024, 1025]).await.unwrap();
        assert!(before.iter().all(Option::is_none));

        let manager = TableRouteManager::new(backend.clone());
        let logical_on_1023 = || {
            TableRouteValue::Logical(LogicalTableRouteValue {
                physical_table_id: 1023,
            })
        };
        let seeded = [(1024, logical_on_1023()), (1025, logical_on_1023())];
        for (table_id, route) in &seeded {
            let (create_txn, _) = manager
                .table_route_storage()
                .build_create_txn(*table_id, route)
                .unwrap();
            let response = backend.txn(create_txn).await.unwrap();
            assert!(response.succeeded);
        }

        // Mix unknown ids with the stored ones, in a different order than they were written.
        let fetched = storage
            .batch_get(&[9999, 1025, 8888, 1024])
            .await
            .unwrap();
        assert!(fetched[0].is_none());
        assert_eq!(fetched[1].as_ref().unwrap(), &seeded[1].1);
        assert!(fetched[2].is_none());
        assert_eq!(fetched[3].as_ref().unwrap(), &seeded[0].1);
    }

    /// Remapping fills in the address of a peer that has a stored node address
    /// and leaves a peer without one untouched.
    #[tokio::test]
    async fn remap_route_address_updates_addresses() {
        let backend = Arc::new(MemoryKvBackend::default());
        let storage = TableRouteStorage::new(backend.clone());

        let peer_with_id = |id| Peer {
            id,
            ..Default::default()
        };
        let mut route = TableRouteValue::Physical(PhysicalTableRouteValue {
            region_routes: vec![RegionRoute {
                leader_peer: Some(peer_with_id(1)),
                follower_peers: vec![peer_with_id(2)],
                ..Default::default()
            }],
            max_region_number: 0,
            version: 0,
        });

        // Only datanode 1 gets a stored address; datanode 2 is deliberately left out.
        let stored_address = NodeAddressValue {
            peer: Peer {
                addr: "addr1".to_string(),
                ..Default::default()
            },
        };
        backend
            .put(PutRequest {
                key: NodeAddressKey::with_datanode(1).to_bytes(),
                value: stored_address.try_as_raw_value().unwrap(),
                ..Default::default()
            })
            .await
            .unwrap();

        storage.remap_route_address(&mut route).await.unwrap();

        let TableRouteValue::Physical(physical) = route else {
            panic!("Expected PhysicalTableRouteValue");
        };
        let remapped = &physical.region_routes[0];
        assert_eq!(remapped.leader_peer.as_ref().unwrap().addr, "addr1");
        assert_eq!(remapped.follower_peers[0].addr, "");
    }
}