common_meta/key/
table_route.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::{HashMap, HashSet};
16use std::fmt::Display;
17use std::sync::Arc;
18
19use serde::{Deserialize, Deserializer, Serialize};
20use snafu::{OptionExt, ResultExt, ensure};
21use store_api::storage::{RegionId, RegionNumber};
22use table::metadata::TableId;
23
24use crate::error::{
25    InvalidMetadataSnafu, MetadataCorruptionSnafu, RegionNotFoundSnafu, Result, SerdeJsonSnafu,
26    TableRouteNotFoundSnafu, UnexpectedLogicalRouteTableSnafu,
27};
28use crate::key::node_address::{NodeAddressKey, NodeAddressValue};
29use crate::key::txn_helper::TxnOpGetResponseSet;
30use crate::key::{
31    DeserializedValueWithBytes, MetadataKey, MetadataValue, RegionDistribution,
32    TABLE_ROUTE_KEY_PATTERN, TABLE_ROUTE_PREFIX,
33};
34use crate::kv_backend::KvBackendRef;
35use crate::kv_backend::txn::Txn;
36use crate::rpc::router::{RegionRoute, region_distribution};
37use crate::rpc::store::BatchGetRequest;
38
39/// The key stores table routes
40///
41/// The layout: `__table_route/{table_id}`.
42#[derive(Debug, PartialEq)]
43pub struct TableRouteKey {
44    pub table_id: TableId,
45}
46
47impl TableRouteKey {
48    pub fn new(table_id: TableId) -> Self {
49        Self { table_id }
50    }
51
52    /// Returns the range prefix of the table route key.
53    pub fn range_prefix() -> Vec<u8> {
54        format!("{}/", TABLE_ROUTE_PREFIX).into_bytes()
55    }
56}
57
58#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
59#[serde(tag = "type", rename_all = "snake_case")]
60pub enum TableRouteValue {
61    Physical(PhysicalTableRouteValue),
62    Logical(LogicalTableRouteValue),
63}
64
65#[derive(Debug, PartialEq, Serialize, Clone, Default)]
66pub struct PhysicalTableRouteValue {
67    // The region routes of the table.
68    pub region_routes: Vec<RegionRoute>,
69    // Tracks the highest region number ever allocated for the table.
70    // This value only increases: adding a region updates it if needed,
71    // and dropping regions does not decrease it.
72    pub max_region_number: RegionNumber,
73    // The version of the table route.
74    version: u64,
75}
76
77impl<'de> Deserialize<'de> for PhysicalTableRouteValue {
78    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
79    where
80        D: Deserializer<'de>,
81    {
82        #[derive(Deserialize)]
83        struct Helper {
84            region_routes: Vec<RegionRoute>,
85            #[serde(default)]
86            max_region_number: Option<RegionNumber>,
87            version: u64,
88        }
89
90        let mut helper = Helper::deserialize(deserializer)?;
91        // If the max region number is not provided, we will calculate it from the region routes.
92        if helper.max_region_number.is_none() {
93            let max_region = helper
94                .region_routes
95                .iter()
96                .map(|r| r.region.id.region_number())
97                .max()
98                .unwrap_or_default();
99            helper.max_region_number = Some(max_region);
100        }
101
102        Ok(PhysicalTableRouteValue {
103            region_routes: helper.region_routes,
104            max_region_number: helper.max_region_number.unwrap_or_default(),
105            version: helper.version,
106        })
107    }
108}
109
110#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
111pub struct LogicalTableRouteValue {
112    physical_table_id: TableId,
113}
114
115impl TableRouteValue {
116    /// Returns a [TableRouteValue::Physical] if `table_id` equals `physical_table_id`.
117    /// Otherwise returns a [TableRouteValue::Logical].
118    pub(crate) fn new(
119        table_id: TableId,
120        physical_table_id: TableId,
121        region_routes: Vec<RegionRoute>,
122    ) -> Self {
123        if table_id == physical_table_id {
124            TableRouteValue::physical(region_routes)
125        } else {
126            TableRouteValue::logical(physical_table_id)
127        }
128    }
129
130    pub fn physical(region_routes: Vec<RegionRoute>) -> Self {
131        Self::Physical(PhysicalTableRouteValue::new(region_routes))
132    }
133
134    pub fn logical(physical_table_id: TableId) -> Self {
135        Self::Logical(LogicalTableRouteValue::new(physical_table_id))
136    }
137
138    /// Returns a new version [TableRouteValue] with `region_routes`.
139    pub fn update(&self, region_routes: Vec<RegionRoute>) -> Result<Self> {
140        ensure!(
141            self.is_physical(),
142            UnexpectedLogicalRouteTableSnafu {
143                err_msg: format!("{self:?} is a non-physical TableRouteValue."),
144            }
145        );
146        let physical_table_route = self.as_physical_table_route_ref();
147        let original_max_region_number = physical_table_route.max_region_number;
148        let new_max_region_number = region_routes
149            .iter()
150            .map(|r| r.region.id.region_number())
151            .max()
152            .unwrap_or_default();
153        let version = physical_table_route.version;
154        Ok(Self::Physical(PhysicalTableRouteValue {
155            region_routes,
156            // If region routes are added, we will update the max region number.
157            // If region routes are removed, we will keep the original max region number.
158            max_region_number: original_max_region_number.max(new_max_region_number),
159            version: version + 1,
160        }))
161    }
162
163    /// Returns the version.
164    ///
165    /// For test purpose.
166    #[cfg(any(test, feature = "testing"))]
167    pub fn version(&self) -> Result<u64> {
168        ensure!(
169            self.is_physical(),
170            UnexpectedLogicalRouteTableSnafu {
171                err_msg: format!("{self:?} is a non-physical TableRouteValue."),
172            }
173        );
174        Ok(self.as_physical_table_route_ref().version)
175    }
176
177    /// Returns the corresponding [RegionRoute], returns `None` if it's the specific region is not found.
178    ///
179    /// Note: It throws an error if it's a logical table
180    pub fn region_route(&self, region_id: RegionId) -> Result<Option<RegionRoute>> {
181        ensure!(
182            self.is_physical(),
183            UnexpectedLogicalRouteTableSnafu {
184                err_msg: format!("{self:?} is a non-physical TableRouteValue."),
185            }
186        );
187        Ok(self
188            .as_physical_table_route_ref()
189            .region_routes
190            .iter()
191            .find(|route| route.region.id == region_id)
192            .cloned())
193    }
194
195    /// Returns true if it's [TableRouteValue::Physical].
196    pub fn is_physical(&self) -> bool {
197        matches!(self, TableRouteValue::Physical(_))
198    }
199
200    /// Gets the [RegionRoute]s of this [TableRouteValue::Physical].
201    pub fn region_routes(&self) -> Result<&Vec<RegionRoute>> {
202        ensure!(
203            self.is_physical(),
204            UnexpectedLogicalRouteTableSnafu {
205                err_msg: format!("{self:?} is a non-physical TableRouteValue."),
206            }
207        );
208        Ok(&self.as_physical_table_route_ref().region_routes)
209    }
210
211    /// Returns the max region number of this [TableRouteValue::Physical].
212    ///
213    /// # Panic
214    /// If it is not the [`PhysicalTableRouteValue`].
215    pub fn max_region_number(&self) -> Result<RegionNumber> {
216        ensure!(
217            self.is_physical(),
218            UnexpectedLogicalRouteTableSnafu {
219                err_msg: format!("{self:?} is a non-physical TableRouteValue."),
220            }
221        );
222        Ok(self.as_physical_table_route_ref().max_region_number)
223    }
224
225    /// Returns the reference of [`PhysicalTableRouteValue`].
226    ///
227    /// # Panic
228    /// If it is not the [`PhysicalTableRouteValue`].
229    fn as_physical_table_route_ref(&self) -> &PhysicalTableRouteValue {
230        match self {
231            TableRouteValue::Physical(x) => x,
232            _ => unreachable!("Mistakenly been treated as a Physical TableRoute: {self:?}"),
233        }
234    }
235
236    /// Converts to [`PhysicalTableRouteValue`].
237    ///
238    /// # Panic
239    /// If it is not the [`PhysicalTableRouteValue`].
240    pub fn into_physical_table_route(self) -> PhysicalTableRouteValue {
241        match self {
242            TableRouteValue::Physical(x) => x,
243            _ => unreachable!("Mistakenly been treated as a Physical TableRoute: {self:?}"),
244        }
245    }
246
247    /// Converts to [`LogicalTableRouteValue`].
248    ///
249    /// # Panic
250    /// If it is not the [`LogicalTableRouteValue`].
251    pub fn into_logical_table_route(self) -> LogicalTableRouteValue {
252        match self {
253            TableRouteValue::Logical(x) => x,
254            _ => unreachable!("Mistakenly been treated as a Logical TableRoute: {self:?}"),
255        }
256    }
257
258    pub fn region_numbers(&self) -> Vec<RegionNumber> {
259        match self {
260            TableRouteValue::Physical(x) => x
261                .region_routes
262                .iter()
263                .map(|region_route| region_route.region.id.region_number())
264                .collect(),
265            TableRouteValue::Logical(_) => {
266                vec![]
267            }
268        }
269    }
270}
271
272impl MetadataValue for TableRouteValue {
273    fn try_from_raw_value(raw_value: &[u8]) -> Result<Self> {
274        let r = serde_json::from_slice::<TableRouteValue>(raw_value);
275        match r {
276            // Compatible with old TableRouteValue.
277            Err(e) if e.is_data() => Ok(Self::Physical(
278                serde_json::from_slice::<PhysicalTableRouteValue>(raw_value)
279                    .context(SerdeJsonSnafu)?,
280            )),
281            Ok(x) => Ok(x),
282            Err(e) => Err(e).context(SerdeJsonSnafu),
283        }
284    }
285
286    fn try_as_raw_value(&self) -> Result<Vec<u8>> {
287        serde_json::to_vec(self).context(SerdeJsonSnafu)
288    }
289}
290
291impl PhysicalTableRouteValue {
292    /// Returns the version of the table route.
293    pub fn version(&self) -> u64 {
294        self.version
295    }
296    pub fn new(region_routes: Vec<RegionRoute>) -> Self {
297        let max_region_number = region_routes
298            .iter()
299            .map(|r| r.region.id.region_number())
300            .max()
301            .unwrap_or_default();
302        Self {
303            region_routes,
304            max_region_number,
305            version: 0,
306        }
307    }
308}
309
310impl LogicalTableRouteValue {
311    pub fn new(physical_table_id: TableId) -> Self {
312        Self { physical_table_id }
313    }
314
315    pub fn physical_table_id(&self) -> TableId {
316        self.physical_table_id
317    }
318}
319
320impl MetadataKey<'_, TableRouteKey> for TableRouteKey {
321    fn to_bytes(&self) -> Vec<u8> {
322        self.to_string().into_bytes()
323    }
324
325    fn from_bytes(bytes: &[u8]) -> Result<TableRouteKey> {
326        let key = std::str::from_utf8(bytes).map_err(|e| {
327            InvalidMetadataSnafu {
328                err_msg: format!(
329                    "TableRouteKey '{}' is not a valid UTF8 string: {e}",
330                    String::from_utf8_lossy(bytes)
331                ),
332            }
333            .build()
334        })?;
335        let captures = TABLE_ROUTE_KEY_PATTERN
336            .captures(key)
337            .context(InvalidMetadataSnafu {
338                err_msg: format!("Invalid TableRouteKey '{key}'"),
339            })?;
340        // Safety: pass the regex check above
341        let table_id = captures[1].parse::<TableId>().unwrap();
342        Ok(TableRouteKey { table_id })
343    }
344}
345
346impl Display for TableRouteKey {
347    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348        write!(f, "{}/{}", TABLE_ROUTE_PREFIX, self.table_id)
349    }
350}
351
352pub type TableRouteManagerRef = Arc<TableRouteManager>;
353
354pub struct TableRouteManager {
355    storage: TableRouteStorage,
356}
357
358impl TableRouteManager {
359    pub fn new(kv_backend: KvBackendRef) -> Self {
360        Self {
361            storage: TableRouteStorage::new(kv_backend),
362        }
363    }
364
365    /// Returns the [TableId] recursively.
366    ///
367    /// Returns a [TableRouteNotFound](crate::error::Error::TableRouteNotFound) Error if:
368    /// - the table(`logical_or_physical_table_id`) does not exist.
369    pub async fn get_physical_table_id(
370        &self,
371        logical_or_physical_table_id: TableId,
372    ) -> Result<TableId> {
373        let table_route = self
374            .storage
375            .get_inner(logical_or_physical_table_id)
376            .await?
377            .context(TableRouteNotFoundSnafu {
378                table_id: logical_or_physical_table_id,
379            })?;
380
381        match table_route {
382            TableRouteValue::Physical(_) => Ok(logical_or_physical_table_id),
383            TableRouteValue::Logical(x) => Ok(x.physical_table_id()),
384        }
385    }
386
387    /// Returns the [TableRouteValue::Physical] recursively.
388    ///
389    /// Returns a [TableRouteNotFound](error::Error::TableRouteNotFound) Error if:
390    /// - the physical table(`logical_or_physical_table_id`) does not exist
391    /// - the corresponding physical table of the logical table(`logical_or_physical_table_id`) does not exist.
392    pub async fn get_physical_table_route(
393        &self,
394        logical_or_physical_table_id: TableId,
395    ) -> Result<(TableId, PhysicalTableRouteValue)> {
396        let table_route = self
397            .storage
398            .get(logical_or_physical_table_id)
399            .await?
400            .context(TableRouteNotFoundSnafu {
401                table_id: logical_or_physical_table_id,
402            })?;
403
404        match table_route {
405            TableRouteValue::Physical(x) => Ok((logical_or_physical_table_id, x)),
406            TableRouteValue::Logical(x) => {
407                let physical_table_id = x.physical_table_id();
408                let physical_table_route = self.storage.get(physical_table_id).await?.context(
409                    TableRouteNotFoundSnafu {
410                        table_id: physical_table_id,
411                    },
412                )?;
413                let physical_table_route = physical_table_route.into_physical_table_route();
414                Ok((physical_table_id, physical_table_route))
415            }
416        }
417    }
418
419    /// Returns the [TableRouteValue::Physical] recursively.
420    ///
421    /// Returns a [TableRouteNotFound](crate::error::Error::TableRouteNotFound) Error if:
422    /// - one of the logical tables corresponding to the physical table does not exist.
423    ///
424    /// **Notes**: it may return a subset of `logical_or_physical_table_ids`.
425    pub async fn batch_get_physical_table_routes(
426        &self,
427        logical_or_physical_table_ids: &[TableId],
428    ) -> Result<HashMap<TableId, PhysicalTableRouteValue>> {
429        let table_routes = self
430            .storage
431            .batch_get(logical_or_physical_table_ids)
432            .await?;
433        // Returns a subset of `logical_or_physical_table_ids`.
434        let table_routes = table_routes
435            .into_iter()
436            .zip(logical_or_physical_table_ids)
437            .filter_map(|(route, id)| route.map(|route| (*id, route)))
438            .collect::<HashMap<_, _>>();
439
440        let mut physical_table_routes = HashMap::with_capacity(table_routes.len());
441        let mut logical_table_ids = HashMap::with_capacity(table_routes.len());
442
443        for (table_id, table_route) in table_routes {
444            match table_route {
445                TableRouteValue::Physical(x) => {
446                    physical_table_routes.insert(table_id, x);
447                }
448                TableRouteValue::Logical(x) => {
449                    logical_table_ids.insert(table_id, x.physical_table_id());
450                }
451            }
452        }
453
454        if logical_table_ids.is_empty() {
455            return Ok(physical_table_routes);
456        }
457
458        // Finds the logical tables corresponding to the physical tables.
459        let physical_table_ids = logical_table_ids
460            .values()
461            .cloned()
462            .collect::<HashSet<_>>()
463            .into_iter()
464            .collect::<Vec<_>>();
465        let table_routes = self
466            .table_route_storage()
467            .batch_get(&physical_table_ids)
468            .await?;
469        let table_routes = table_routes
470            .into_iter()
471            .zip(physical_table_ids)
472            .filter_map(|(route, id)| route.map(|route| (id, route)))
473            .collect::<HashMap<_, _>>();
474
475        for (logical_table_id, physical_table_id) in logical_table_ids {
476            let table_route =
477                table_routes
478                    .get(&physical_table_id)
479                    .context(TableRouteNotFoundSnafu {
480                        table_id: physical_table_id,
481                    })?;
482            match table_route {
483                TableRouteValue::Physical(x) => {
484                    physical_table_routes.insert(logical_table_id, x.clone());
485                }
486                TableRouteValue::Logical(x) => {
487                    // Never get here, because we use a physical table id cannot obtain a logical table.
488                    MetadataCorruptionSnafu {
489                        err_msg: format!(
490                            "logical table {} {:?} cannot be resolved to a physical table.",
491                            logical_table_id, x
492                        ),
493                    }
494                    .fail()?;
495                }
496            }
497        }
498
499        Ok(physical_table_routes)
500    }
501
502    /// Returns [`RegionDistribution`] of the table(`table_id`).
503    pub async fn get_region_distribution(
504        &self,
505        table_id: TableId,
506    ) -> Result<Option<RegionDistribution>> {
507        self.storage
508            .get(table_id)
509            .await?
510            .map(|table_route| Ok(region_distribution(table_route.region_routes()?)))
511            .transpose()
512    }
513
514    /// Sets the staging state for a specific region.
515    ///
516    /// Returns a [TableRouteNotFound](crate::error::Error::TableRouteNotFound) Error if:
517    /// - the table does not exist
518    /// - the region is not found in the table
519    pub async fn set_region_staging_state(
520        &self,
521        region_id: store_api::storage::RegionId,
522        staging: bool,
523    ) -> Result<()> {
524        let table_id = region_id.table_id();
525
526        // Get current table route with raw bytes for CAS operation
527        let current_table_route = self
528            .storage
529            .get_with_raw_bytes(table_id)
530            .await?
531            .context(TableRouteNotFoundSnafu { table_id })?;
532
533        // Clone the current route value and update the specific region
534        let new_table_route = current_table_route.inner.clone();
535
536        // Only physical tables have region routes
537        ensure!(
538            new_table_route.is_physical(),
539            UnexpectedLogicalRouteTableSnafu {
540                err_msg: format!("Cannot set staging state for logical table {table_id}"),
541            }
542        );
543
544        let region_routes = new_table_route.region_routes()?.clone();
545        let mut updated_routes = region_routes.clone();
546
547        // Find and update the specific region
548        // TODO(ruihang): maybe update them in one transaction
549        let mut region_found = false;
550        for route in &mut updated_routes {
551            if route.region.id == region_id {
552                if staging {
553                    route.set_leader_staging();
554                } else {
555                    route.clear_leader_staging();
556                }
557                region_found = true;
558                break;
559            }
560        }
561
562        ensure!(region_found, RegionNotFoundSnafu { region_id });
563
564        // Create new table route with updated region routes
565        let updated_table_route = new_table_route.update(updated_routes)?;
566
567        // Execute atomic update
568        let (txn, _) =
569            self.storage
570                .build_update_txn(table_id, &current_table_route, &updated_table_route)?;
571
572        let result = self.storage.kv_backend.txn(txn).await?;
573
574        ensure!(
575            result.succeeded,
576            MetadataCorruptionSnafu {
577                err_msg: format!(
578                    "Failed to update staging state for region {}: CAS operation failed",
579                    region_id
580                ),
581            }
582        );
583
584        Ok(())
585    }
586
587    /// Checks if a specific region is in staging state.
588    ///
589    /// Returns false if the table/region doesn't exist.
590    pub async fn is_region_staging(&self, region_id: store_api::storage::RegionId) -> Result<bool> {
591        let table_id = region_id.table_id();
592
593        let table_route = self.storage.get(table_id).await?;
594
595        match table_route {
596            Some(route) if route.is_physical() => {
597                let region_routes = route.region_routes()?;
598                for route in region_routes {
599                    if route.region.id == region_id {
600                        return Ok(route.is_leader_staging());
601                    }
602                }
603                Ok(false)
604            }
605            _ => Ok(false),
606        }
607    }
608
609    /// Returns low-level APIs.
610    pub fn table_route_storage(&self) -> &TableRouteStorage {
611        &self.storage
612    }
613}
614
615/// Low-level operations of [TableRouteValue].
616pub struct TableRouteStorage {
617    kv_backend: KvBackendRef,
618}
619
620pub type TableRouteValueDecodeResult = Result<Option<DeserializedValueWithBytes<TableRouteValue>>>;
621
622impl TableRouteStorage {
623    pub fn new(kv_backend: KvBackendRef) -> Self {
624        Self { kv_backend }
625    }
626
627    /// Builds a create table route transaction,
628    /// it expected the `__table_route/{table_id}` wasn't occupied.
629    pub fn build_create_txn(
630        &self,
631        table_id: TableId,
632        table_route_value: &TableRouteValue,
633    ) -> Result<(
634        Txn,
635        impl FnOnce(&mut TxnOpGetResponseSet) -> TableRouteValueDecodeResult + use<>,
636    )> {
637        let key = TableRouteKey::new(table_id);
638        let raw_key = key.to_bytes();
639
640        let txn = Txn::put_if_not_exists(raw_key.clone(), table_route_value.try_as_raw_value()?);
641
642        Ok((
643            txn,
644            TxnOpGetResponseSet::decode_with(TxnOpGetResponseSet::filter(raw_key)),
645        ))
646    }
647
648    // TODO(LFC): restore its original visibility after some test utility codes are refined
649    /// Builds a update table route transaction,
650    /// it expected the remote value equals the `current_table_route_value`.
651    /// It retrieves the latest value if the comparing failed.
652    pub fn build_update_txn(
653        &self,
654        table_id: TableId,
655        current_table_route_value: &DeserializedValueWithBytes<TableRouteValue>,
656        new_table_route_value: &TableRouteValue,
657    ) -> Result<(
658        Txn,
659        impl FnOnce(&mut TxnOpGetResponseSet) -> TableRouteValueDecodeResult + use<>,
660    )> {
661        let key = TableRouteKey::new(table_id);
662        let raw_key = key.to_bytes();
663        let raw_value = current_table_route_value.get_raw_bytes();
664        let new_raw_value: Vec<u8> = new_table_route_value.try_as_raw_value()?;
665
666        let txn = Txn::compare_and_put(raw_key.clone(), raw_value, new_raw_value);
667
668        Ok((
669            txn,
670            TxnOpGetResponseSet::decode_with(TxnOpGetResponseSet::filter(raw_key)),
671        ))
672    }
673
674    /// Returns the [`TableRouteValue`].
675    pub async fn get(&self, table_id: TableId) -> Result<Option<TableRouteValue>> {
676        let mut table_route = self.get_inner(table_id).await?;
677        if let Some(table_route) = &mut table_route {
678            self.remap_route_address(table_route).await?;
679        };
680
681        Ok(table_route)
682    }
683
684    async fn get_inner(&self, table_id: TableId) -> Result<Option<TableRouteValue>> {
685        let key = TableRouteKey::new(table_id);
686        self.kv_backend
687            .get(&key.to_bytes())
688            .await?
689            .map(|kv| TableRouteValue::try_from_raw_value(&kv.value))
690            .transpose()
691    }
692
693    /// Returns the [`TableRouteValue`] wrapped with [`DeserializedValueWithBytes`].
694    pub async fn get_with_raw_bytes(
695        &self,
696        table_id: TableId,
697    ) -> Result<Option<DeserializedValueWithBytes<TableRouteValue>>> {
698        let mut table_route = self.get_with_raw_bytes_inner(table_id).await?;
699        if let Some(table_route) = &mut table_route {
700            self.remap_route_address(table_route).await?;
701        };
702
703        Ok(table_route)
704    }
705
706    async fn get_with_raw_bytes_inner(
707        &self,
708        table_id: TableId,
709    ) -> Result<Option<DeserializedValueWithBytes<TableRouteValue>>> {
710        let key = TableRouteKey::new(table_id);
711        self.kv_backend
712            .get(&key.to_bytes())
713            .await?
714            .map(|kv| DeserializedValueWithBytes::from_inner_slice(&kv.value))
715            .transpose()
716    }
717
718    /// Returns batch of [`TableRouteValue`] that respects the order of `table_ids`.
719    pub async fn batch_get(&self, table_ids: &[TableId]) -> Result<Vec<Option<TableRouteValue>>> {
720        let raw_table_routes = self.batch_get_inner(table_ids).await?;
721
722        Ok(raw_table_routes
723            .into_iter()
724            .map(|v| v.map(|x| x.inner))
725            .collect())
726    }
727
728    /// Returns batch of [`TableRouteValue`] wrapped with [`DeserializedValueWithBytes`].
729    ///
730    /// The return value is a vector of [`Option<DeserializedValueWithBytes<TableRouteValue>>`].
731    /// Note: This method remaps the addresses of the table routes, but does not update their raw byte representations.
732    pub async fn batch_get_with_raw_bytes(
733        &self,
734        table_ids: &[TableId],
735    ) -> Result<Vec<Option<DeserializedValueWithBytes<TableRouteValue>>>> {
736        let mut raw_table_routes = self.batch_get_inner(table_ids).await?;
737        self.remap_routes_addresses(&mut raw_table_routes).await?;
738
739        Ok(raw_table_routes)
740    }
741
742    async fn batch_get_inner(
743        &self,
744        table_ids: &[TableId],
745    ) -> Result<Vec<Option<DeserializedValueWithBytes<TableRouteValue>>>> {
746        let keys = table_ids
747            .iter()
748            .map(|id| TableRouteKey::new(*id).to_bytes())
749            .collect::<Vec<_>>();
750        let resp = self
751            .kv_backend
752            .batch_get(BatchGetRequest { keys: keys.clone() })
753            .await?;
754
755        let kvs = resp
756            .kvs
757            .into_iter()
758            .map(|kv| (kv.key, kv.value))
759            .collect::<HashMap<_, _>>();
760        keys.into_iter()
761            .map(|key| {
762                if let Some(value) = kvs.get(&key) {
763                    Ok(Some(DeserializedValueWithBytes::from_inner_slice(value)?))
764                } else {
765                    Ok(None)
766                }
767            })
768            .collect()
769    }
770
771    async fn remap_routes_addresses(
772        &self,
773        table_routes: &mut [Option<DeserializedValueWithBytes<TableRouteValue>>],
774    ) -> Result<()> {
775        let keys = table_routes
776            .iter()
777            .flat_map(|table_route| {
778                table_route
779                    .as_ref()
780                    .map(|x| extract_address_keys(&x.inner))
781                    .unwrap_or_default()
782            })
783            .collect::<HashSet<_>>()
784            .into_iter()
785            .collect();
786        let node_addrs = self.get_node_addresses(keys).await?;
787        for table_route in table_routes.iter_mut().flatten() {
788            set_addresses(&node_addrs, table_route)?;
789        }
790
791        Ok(())
792    }
793
794    pub(crate) async fn remap_route_address(
795        &self,
796        table_route: &mut TableRouteValue,
797    ) -> Result<()> {
798        let keys = extract_address_keys(table_route).into_iter().collect();
799        let node_addrs = self.get_node_addresses(keys).await?;
800        set_addresses(&node_addrs, table_route)?;
801
802        Ok(())
803    }
804
805    async fn get_node_addresses(
806        &self,
807        keys: Vec<Vec<u8>>,
808    ) -> Result<HashMap<u64, NodeAddressValue>> {
809        if keys.is_empty() {
810            return Ok(HashMap::default());
811        }
812
813        self.kv_backend
814            .batch_get(BatchGetRequest { keys })
815            .await?
816            .kvs
817            .into_iter()
818            .map(|kv| {
819                let node_id = NodeAddressKey::from_bytes(&kv.key)?.node_id;
820                let node_addr = NodeAddressValue::try_from_raw_value(&kv.value)?;
821                Ok((node_id, node_addr))
822            })
823            .collect()
824    }
825}
826
827fn set_addresses(
828    node_addrs: &HashMap<u64, NodeAddressValue>,
829    table_route: &mut TableRouteValue,
830) -> Result<()> {
831    let TableRouteValue::Physical(physical_table_route) = table_route else {
832        return Ok(());
833    };
834
835    for region_route in &mut physical_table_route.region_routes {
836        if let Some(leader) = &mut region_route.leader_peer
837            && let Some(node_addr) = node_addrs.get(&leader.id)
838        {
839            leader.addr = node_addr.peer.addr.clone();
840        }
841        for follower in &mut region_route.follower_peers {
842            if let Some(node_addr) = node_addrs.get(&follower.id) {
843                follower.addr = node_addr.peer.addr.clone();
844            }
845        }
846    }
847
848    Ok(())
849}
850
851fn extract_address_keys(table_route: &TableRouteValue) -> HashSet<Vec<u8>> {
852    let TableRouteValue::Physical(physical_table_route) = table_route else {
853        return HashSet::default();
854    };
855
856    physical_table_route
857        .region_routes
858        .iter()
859        .flat_map(|region_route| {
860            region_route
861                .follower_peers
862                .iter()
863                .map(|peer| NodeAddressKey::with_datanode(peer.id).to_bytes())
864                .chain(
865                    region_route
866                        .leader_peer
867                        .as_ref()
868                        .map(|leader| NodeAddressKey::with_datanode(leader.id).to_bytes()),
869                )
870        })
871        .collect()
872}
873
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::kv_backend::memory::MemoryKvBackend;
    use crate::kv_backend::{KvBackend, TxnService};
    use crate::peer::Peer;
    use crate::rpc::router::Region;
    use crate::rpc::store::PutRequest;

    /// Builds a region route that only sets the region id; everything else is defaulted.
    fn route_of(region_id: RegionId) -> RegionRoute {
        let region = Region {
            id: region_id,
            ..Default::default()
        };
        RegionRoute {
            region,
            ..Default::default()
        }
    }

    /// Builds the fully specified route expected for each entry of the legacy JSON fixture.
    fn legacy_route() -> RegionRoute {
        RegionRoute {
            region: Region {
                id: RegionId::new(0, 1),
                name: "r1".to_string(),
                attrs: Default::default(),
                partition_expr: Default::default(),
            },
            leader_peer: Some(Peer {
                id: 2,
                addr: "a2".to_string(),
            }),
            follower_peers: vec![],
            leader_state: None,
            leader_down_since: None,
            write_route_policy: None,
        }
    }

    /// Builds a logical route that points at physical table 1023.
    fn sample_logical_route() -> TableRouteValue {
        TableRouteValue::Logical(LogicalTableRouteValue {
            physical_table_id: 1023,
        })
    }

    #[test]
    fn test_update_table_route_max_region_number() {
        let initial = PhysicalTableRouteValue::new(vec![
            route_of(RegionId::new(0, 1)),
            route_of(RegionId::new(0, 2)),
        ]);
        assert_eq!(initial.max_region_number, 2);

        // Updating to a lower-numbered region must keep the recorded maximum.
        let shrunk = TableRouteValue::Physical(initial)
            .update(vec![route_of(RegionId::new(0, 1))])
            .unwrap();
        let shrunk_max = shrunk.as_physical_table_route_ref().max_region_number;
        assert_eq!(shrunk_max, 2);

        // Updating to a higher-numbered region must raise the recorded maximum.
        let grown = shrunk
            .update(vec![route_of(RegionId::new(0, 3))])
            .unwrap()
            .into_physical_table_route();
        assert_eq!(grown.max_region_number, 3);
    }

    #[test]
    fn test_table_route_compatibility() {
        // Raw value without a `max_region_number` field; decoding is expected to
        // produce a physical route with it set to 1.
        let old_raw_v = r#"{"region_routes":[{"region":{"id":1,"name":"r1","partition":null,"attrs":{}},"leader_peer":{"id":2,"addr":"a2"},"follower_peers":[]},{"region":{"id":1,"name":"r1","partition":null,"attrs":{}},"leader_peer":{"id":2,"addr":"a2"},"follower_peers":[]}],"version":0}"#;
        let decoded = TableRouteValue::try_from_raw_value(old_raw_v.as_bytes()).unwrap();

        let expected = TableRouteValue::Physical(PhysicalTableRouteValue {
            region_routes: vec![legacy_route(), legacy_route()],
            max_region_number: 1,
            version: 0,
        });

        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_key_serialization() {
        let encoded = TableRouteKey::new(42).to_bytes();
        assert_eq!(encoded, b"__table_route/42");
    }

    #[test]
    fn test_key_deserialization() {
        let decoded = TableRouteKey::from_bytes(b"__table_route/42").unwrap();
        assert_eq!(decoded, TableRouteKey::new(42));
    }

    #[tokio::test]
    async fn test_table_route_storage_get_with_raw_bytes_empty() {
        let backend = Arc::new(MemoryKvBackend::default());
        let storage = TableRouteStorage::new(backend);

        let missing = storage.get_with_raw_bytes(1024).await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_table_route_storage_get_with_raw_bytes() {
        let backend = Arc::new(MemoryKvBackend::default());
        let storage = TableRouteStorage::new(backend.clone());

        // Nothing is stored yet.
        let before = storage.get_with_raw_bytes(1024).await.unwrap();
        assert!(before.is_none());

        let manager = TableRouteManager::new(backend.clone());
        let route = sample_logical_route();
        let (create_txn, _) = manager
            .table_route_storage()
            .build_create_txn(1024, &route)
            .unwrap();
        let outcome = backend.txn(create_txn).await.unwrap();
        assert!(outcome.succeeded);

        // The stored value must round-trip unchanged.
        let after = storage.get_with_raw_bytes(1024).await.unwrap();
        assert!(after.is_some());
        let fetched = after.unwrap().inner;
        assert_eq!(fetched, route);
    }

    #[tokio::test]
    async fn test_table_route_batch_get() {
        let backend = Arc::new(MemoryKvBackend::default());
        let storage = TableRouteStorage::new(backend.clone());

        // With an empty backend every requested id resolves to `None`.
        let before = storage.batch_get(&[1023, 1024, 1025]).await.unwrap();
        assert!(before.iter().all(Option::is_none));

        let manager = TableRouteManager::new(backend.clone());
        let stored = [
            (1024, sample_logical_route()),
            (1025, sample_logical_route()),
        ];
        for (table_id, route) in &stored {
            let (create_txn, _) = manager
                .table_route_storage()
                .build_create_txn(*table_id, route)
                .unwrap();
            let outcome = backend.txn(create_txn).await.unwrap();
            assert!(outcome.succeeded);
        }

        // Results line up with the requested ids, with `None` for unknown tables.
        let fetched = storage
            .batch_get(&[9999, 1025, 8888, 1024])
            .await
            .unwrap();
        assert!(fetched[0].is_none());
        assert_eq!(fetched[1].as_ref().unwrap(), &stored[1].1);
        assert!(fetched[2].is_none());
        assert_eq!(fetched[3].as_ref().unwrap(), &stored[0].1);
    }

    #[tokio::test]
    async fn remap_route_address_updates_addresses() {
        let backend = Arc::new(MemoryKvBackend::default());
        let storage = TableRouteStorage::new(backend.clone());

        let peer_with_id = |id| Peer {
            id,
            ..Default::default()
        };
        let mut table_route = TableRouteValue::Physical(PhysicalTableRouteValue {
            region_routes: vec![RegionRoute {
                leader_peer: Some(peer_with_id(1)),
                follower_peers: vec![peer_with_id(2)],
                ..Default::default()
            }],
            max_region_number: 0,
            version: 0,
        });

        // Only datanode 1 has an address entry; datanode 2 has none.
        let registered = NodeAddressValue {
            peer: Peer {
                addr: "addr1".to_string(),
                ..Default::default()
            },
        };
        let put_request = PutRequest {
            key: NodeAddressKey::with_datanode(1).to_bytes(),
            value: registered.try_as_raw_value().unwrap(),
            ..Default::default()
        };
        backend.put(put_request).await.unwrap();

        storage
            .remap_route_address(&mut table_route)
            .await
            .unwrap();

        let TableRouteValue::Physical(physical) = table_route else {
            panic!("Expected PhysicalTableRouteValue");
        };
        let route = &physical.region_routes[0];
        // The leader picks up the stored address; the follower keeps its default.
        assert_eq!(route.leader_peer.as_ref().unwrap().addr, "addr1");
        assert_eq!(route.follower_peers[0].addr, "");
    }
}