// common_function/scalars/primary_key.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::{self, Display};
17use std::sync::Arc;
18
19use datafusion_common::arrow::array::{
20    Array, ArrayRef, BinaryArray, BinaryViewArray, DictionaryArray, ListBuilder, StringBuilder,
21};
22use datafusion_common::arrow::datatypes::{DataType, Field};
23use datafusion_common::{DataFusionError, ScalarValue};
24use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
25use datatypes::arrow::datatypes::UInt32Type;
26use datatypes::value::Value;
27use mito_codec::row_converter::{
28    CompositeValues, PrimaryKeyCodec, SortField, build_primary_key_codec_with_fields,
29};
30use store_api::codec::PrimaryKeyEncoding;
31use store_api::metadata::RegionMetadata;
32use store_api::storage::ColumnId;
33use store_api::storage::consts::{PRIMARY_KEY_COLUMN_NAME, ReservedColumnId};
34
35use crate::function::{Function, extract_args};
36use crate::function_registry::FunctionRegistry;
37
/// One decoded primary key column as `(column_name, value)`;
/// `None` represents a SQL NULL value.
type NameValuePair = (String, Option<String>);
39
/// Scalar function that decodes an encoded `__primary_key` column back into
/// human-readable `"name : value"` strings, one list entry per key column.
#[derive(Clone, Debug)]
pub(crate) struct DecodePrimaryKeyFunction {
    // Accepts any three arguments: (encoded pk, encoding name, region metadata JSON).
    signature: Signature,
}
44
/// SQL-visible name of this function.
const NAME: &str = "decode_primary_key";
/// Placeholder rendered in the output for NULL primary key values.
const NULL_VALUE_LITERAL: &str = "null";
47
48impl Default for DecodePrimaryKeyFunction {
49    fn default() -> Self {
50        Self {
51            signature: Signature::any(3, Volatility::Immutable),
52        }
53    }
54}
55
56impl DecodePrimaryKeyFunction {
57    pub fn register(registry: &FunctionRegistry) {
58        registry.register_scalar(Self::default());
59    }
60
61    fn return_data_type() -> DataType {
62        DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)))
63    }
64}
65
impl Function for DecodePrimaryKeyFunction {
    fn name(&self) -> &str {
        NAME
    }

    /// Always returns `List<Utf8>` regardless of the argument types.
    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
        Ok(Self::return_data_type())
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Decodes each primary key in the first argument using the encoding
    /// (2nd argument) and region metadata JSON (3rd argument), producing one
    /// list of `"name : value"` strings per input row.
    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        // `extract_args` validates arity and yields the encoded column as an
        // array; the 2nd/3rd arguments are read from `args.args` directly
        // because the parsers below require them to remain scalar literals.
        let [encoded, _, _] = extract_args(self.name(), &args)?;
        let number_rows = args.number_rows;

        let encoding = parse_encoding(&args.args[1])?;
        let metadata = parse_region_metadata(&args.args[2])?;
        let codec = build_codec(&metadata, encoding);
        // Maps column ids back to human-readable column names.
        let name_lookup: HashMap<_, _> = metadata
            .column_metadatas
            .iter()
            .map(|c| (c.column_id, c.column_schema.name.clone()))
            .collect();

        let decoded_rows = decode_primary_keys(encoded, number_rows, codec.as_ref(), &name_lookup)?;
        let array = build_list_array(&decoded_rows)?;

        Ok(ColumnarValue::Array(array))
    }
}
101
102impl Display for DecodePrimaryKeyFunction {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        write!(f, "DECODE_PRIMARY_KEY")
105    }
106}
107
108fn parse_encoding(arg: &ColumnarValue) -> datafusion_common::Result<PrimaryKeyEncoding> {
109    let encoding = match arg {
110        ColumnarValue::Scalar(ScalarValue::Utf8(Some(v)))
111        | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => v.as_str(),
112        ColumnarValue::Scalar(value) => {
113            return Err(DataFusionError::Execution(format!(
114                "encoding must be a string literal, got {value:?}"
115            )));
116        }
117        ColumnarValue::Array(_) => {
118            return Err(DataFusionError::Execution(
119                "encoding must be a scalar string".to_string(),
120            ));
121        }
122    };
123
124    match encoding.to_ascii_lowercase().as_str() {
125        "dense" => Ok(PrimaryKeyEncoding::Dense),
126        "sparse" => Ok(PrimaryKeyEncoding::Sparse),
127        _ => Err(DataFusionError::Execution(format!(
128            "unsupported primary key encoding: {encoding}"
129        ))),
130    }
131}
132
133fn build_codec(
134    metadata: &RegionMetadata,
135    encoding: PrimaryKeyEncoding,
136) -> Arc<dyn PrimaryKeyCodec> {
137    let fields = metadata.primary_key_columns().map(|c| {
138        (
139            c.column_id,
140            SortField::new(c.column_schema.data_type.clone()),
141        )
142    });
143    build_primary_key_codec_with_fields(encoding, fields)
144}
145
146fn parse_region_metadata(arg: &ColumnarValue) -> datafusion_common::Result<RegionMetadata> {
147    let json = match arg {
148        ColumnarValue::Scalar(ScalarValue::Utf8(Some(v)))
149        | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => v.as_str(),
150        ColumnarValue::Scalar(value) => {
151            return Err(DataFusionError::Execution(format!(
152                "region metadata must be a string literal, got {value:?}"
153            )));
154        }
155        ColumnarValue::Array(_) => {
156            return Err(DataFusionError::Execution(
157                "region metadata must be a scalar string".to_string(),
158            ));
159        }
160    };
161
162    RegionMetadata::from_json(json)
163        .map_err(|e| DataFusionError::Execution(format!("failed to parse region metadata: {e:?}")))
164}
165
166fn decode_primary_keys(
167    encoded: ArrayRef,
168    number_rows: usize,
169    codec: &dyn PrimaryKeyCodec,
170    name_lookup: &HashMap<ColumnId, String>,
171) -> datafusion_common::Result<Vec<Vec<NameValuePair>>> {
172    if let Some(dict) = encoded
173        .as_any()
174        .downcast_ref::<DictionaryArray<UInt32Type>>()
175    {
176        decode_dictionary(dict, number_rows, codec, name_lookup)
177    } else if let Some(array) = encoded.as_any().downcast_ref::<BinaryArray>() {
178        decode_binary_array(array, codec, name_lookup)
179    } else if let Some(array) = encoded.as_any().downcast_ref::<BinaryViewArray>() {
180        decode_binary_view_array(array, codec, name_lookup)
181    } else {
182        Err(DataFusionError::Execution(format!(
183            "column {PRIMARY_KEY_COLUMN_NAME} must be binary or dictionary(binary) array"
184        )))
185    }
186}
187
188fn decode_dictionary(
189    dict: &DictionaryArray<UInt32Type>,
190    number_rows: usize,
191    codec: &dyn PrimaryKeyCodec,
192    name_lookup: &HashMap<ColumnId, String>,
193) -> datafusion_common::Result<Vec<Vec<NameValuePair>>> {
194    let values = dict
195        .values()
196        .as_any()
197        .downcast_ref::<BinaryArray>()
198        .ok_or_else(|| {
199            DataFusionError::Execution("primary key dictionary values are not binary".to_string())
200        })?;
201
202    let mut decoded_values = Vec::with_capacity(values.len());
203    for i in 0..values.len() {
204        let pk = values.value(i);
205        let pairs = decode_one(pk, codec, name_lookup)?;
206        decoded_values.push(pairs);
207    }
208
209    let mut rows = Vec::with_capacity(number_rows);
210    let keys = dict.keys();
211    for i in 0..number_rows {
212        let dict_index = keys.value(i) as usize;
213        rows.push(decoded_values[dict_index].clone());
214    }
215
216    Ok(rows)
217}
218
219fn decode_binary_array(
220    array: &BinaryArray,
221    codec: &dyn PrimaryKeyCodec,
222    name_lookup: &HashMap<ColumnId, String>,
223) -> datafusion_common::Result<Vec<Vec<NameValuePair>>> {
224    (0..array.len())
225        .map(|i| decode_one(array.value(i), codec, name_lookup))
226        .collect()
227}
228
229fn decode_binary_view_array(
230    array: &BinaryViewArray,
231    codec: &dyn PrimaryKeyCodec,
232    name_lookup: &HashMap<ColumnId, String>,
233) -> datafusion_common::Result<Vec<Vec<NameValuePair>>> {
234    (0..array.len())
235        .map(|i| decode_one(array.value(i), codec, name_lookup))
236        .collect()
237}
238
239fn decode_one(
240    pk: &[u8],
241    codec: &dyn PrimaryKeyCodec,
242    name_lookup: &HashMap<ColumnId, String>,
243) -> datafusion_common::Result<Vec<NameValuePair>> {
244    let decoded = codec
245        .decode(pk)
246        .map_err(|e| DataFusionError::Execution(format!("failed to decode primary key: {e}")))?;
247
248    Ok(match decoded {
249        CompositeValues::Dense(values) => values
250            .into_iter()
251            .map(|(column_id, value)| (column_name(column_id, name_lookup), value_to_string(value)))
252            .collect(),
253        CompositeValues::Sparse(values) => {
254            let mut values: Vec<_> = values
255                .iter()
256                .map(|(column_id, value)| {
257                    (
258                        *column_id,
259                        column_name(*column_id, name_lookup),
260                        value_to_string(value.clone()),
261                    )
262                })
263                .collect();
264            values.sort_by_key(|(column_id, _, _)| {
265                (ReservedColumnId::is_reserved(*column_id), *column_id)
266            });
267            values
268                .into_iter()
269                .map(|(_, name, value)| (name, value))
270                .collect()
271        }
272    })
273}
274
275fn column_name(column_id: ColumnId, name_lookup: &HashMap<ColumnId, String>) -> String {
276    if let Some(name) = name_lookup.get(&column_id) {
277        return name.clone();
278    }
279
280    if column_id == ReservedColumnId::table_id() {
281        return "__table_id".to_string();
282    }
283    if column_id == ReservedColumnId::tsid() {
284        return "__tsid".to_string();
285    }
286
287    column_id.to_string()
288}
289
290fn value_to_string(value: Value) -> Option<String> {
291    match value {
292        Value::Null => None,
293        _ => Some(value.to_string()),
294    }
295}
296
297fn build_list_array(rows: &[Vec<NameValuePair>]) -> datafusion_common::Result<ArrayRef> {
298    let mut builder = ListBuilder::new(StringBuilder::new());
299
300    for row in rows {
301        for (key, value) in row {
302            let value = value.as_deref().unwrap_or(NULL_VALUE_LITERAL);
303            builder.values().append_value(format!("{key} : {value}"));
304        }
305        builder.append(true);
306    }
307
308    Ok(Arc::new(builder.finish()))
309}
310
#[cfg(test)]
mod tests {
    use api::v1::SemanticType;
    use datafusion_common::ScalarValue;
    use datatypes::arrow::array::builder::BinaryDictionaryBuilder;
    use datatypes::arrow::array::{BinaryArray, ListArray, StringArray};
    use datatypes::arrow::datatypes::UInt32Type;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::value::Value;
    use mito_codec::row_converter::{
        DensePrimaryKeyCodec, PrimaryKeyCodecExt, SortField, SparsePrimaryKeyCodec,
    };
    use store_api::codec::PrimaryKeyEncoding;
    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
    use store_api::storage::consts::ReservedColumnId;
    use store_api::storage::{ColumnId, RegionId};

    use super::*;

    /// Field describing the `__primary_key` column as dictionary(uint32, binary).
    fn pk_field() -> Arc<Field> {
        Arc::new(Field::new_dictionary(
            PRIMARY_KEY_COLUMN_NAME,
            DataType::UInt32,
            DataType::Binary,
            false,
        ))
    }

    /// Builds region metadata JSON with a timestamp column (id 100) plus the
    /// given tag columns as the primary key, using `encoding`.
    fn region_metadata_json(
        columns: &[(ColumnId, &str, ConcreteDataType)],
        encoding: PrimaryKeyEncoding,
    ) -> String {
        let mut builder = RegionMetadataBuilder::new(RegionId::new(1, 1));
        builder.push_column_metadata(ColumnMetadata {
            column_schema: ColumnSchema::new(
                "ts",
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            ),
            semantic_type: SemanticType::Timestamp,
            column_id: 100,
        });
        builder.primary_key_encoding(encoding);
        for (id, name, ty) in columns {
            builder.push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new((*name).to_string(), ty.clone(), true),
                semantic_type: SemanticType::Tag,
                column_id: *id,
            });
        }
        builder.primary_key(columns.iter().map(|(id, _, _)| *id).collect());

        builder.build().unwrap().to_json().unwrap()
    }

    /// Extracts row `row_idx` of a `List<Utf8>` result as plain strings.
    fn list_row(list: &ListArray, row_idx: usize) -> Vec<String> {
        let values = list.value(row_idx);
        let values = values.as_any().downcast_ref::<StringArray>().unwrap();
        (0..values.len())
            .map(|i| values.value(i).to_string())
            .collect()
    }

    /// Dense encoding: keys are stored dictionary-encoded; duplicate rows
    /// share a dictionary entry and must decode to identical output rows.
    #[test]
    fn test_decode_dense_primary_key() {
        let columns = vec![
            (0, "host", ConcreteDataType::string_datatype()),
            (1, "core", ConcreteDataType::int64_datatype()),
        ];
        let metadata_json = region_metadata_json(&columns, PrimaryKeyEncoding::Dense);
        let codec = DensePrimaryKeyCodec::with_fields(
            columns
                .iter()
                .map(|(id, _, ty)| (*id, SortField::new(ty.clone())))
                .collect(),
        );

        // Row 3 repeats row 1 to exercise dictionary key reuse.
        let rows = vec![
            vec![Value::from("a"), Value::from(1_i64)],
            vec![Value::from("b"), Value::from(2_i64)],
            vec![Value::from("a"), Value::from(1_i64)],
        ];

        let mut builder = BinaryDictionaryBuilder::<UInt32Type>::new();
        for row in &rows {
            let encoded = codec.encode(row.iter().map(|v| v.as_value_ref())).unwrap();
            builder.append(encoded.as_slice()).unwrap();
        }
        let dict_array: ArrayRef = Arc::new(builder.finish());

        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(dict_array),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("dense".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(metadata_json))),
            ],
            arg_fields: vec![
                pk_field(),
                Arc::new(Field::new("encoding", DataType::Utf8, false)),
                Arc::new(Field::new("region_metadata", DataType::Utf8, false)),
            ],
            number_rows: 3,
            return_field: Arc::new(Field::new(
                "decoded",
                DecodePrimaryKeyFunction::return_data_type(),
                false,
            )),
            config_options: Default::default(),
        };

        let func = DecodePrimaryKeyFunction::default();
        let result = func
            .invoke_with_args(args)
            .and_then(|v| v.to_array(3))
            .unwrap();
        let list = result.as_any().downcast_ref::<ListArray>().unwrap();

        let expected = [
            vec!["host : a".to_string(), "core : 1".to_string()],
            vec!["host : b".to_string(), "core : 2".to_string()],
            vec!["host : a".to_string(), "core : 1".to_string()],
        ];

        for (row_idx, expected_row) in expected.iter().enumerate() {
            assert_eq!(*expected_row, list_row(list, row_idx));
        }
    }

    /// Sparse encoding: keys arrive as a plain binary array and carry the
    /// reserved `__table_id`/`__tsid` columns, which must sort after the
    /// user-defined tag columns in the decoded output.
    #[test]
    fn test_decode_sparse_primary_key() {
        let columns = vec![
            (10, "k0", ConcreteDataType::string_datatype()),
            (11, "k1", ConcreteDataType::string_datatype()),
        ];
        let metadata_json = region_metadata_json(&columns, PrimaryKeyEncoding::Sparse);
        let codec = SparsePrimaryKeyCodec::schemaless();

        let rows = vec![
            vec![
                (ReservedColumnId::table_id(), Value::UInt32(1)),
                (ReservedColumnId::tsid(), Value::UInt64(100)),
                (10, Value::from("a")),
                (11, Value::from("b")),
            ],
            vec![
                (ReservedColumnId::table_id(), Value::UInt32(1)),
                (ReservedColumnId::tsid(), Value::UInt64(200)),
                (10, Value::from("c")),
                (11, Value::from("d")),
            ],
        ];

        let mut encoded_values = Vec::with_capacity(rows.len());
        for row in &rows {
            let mut buf = Vec::new();
            codec.encode_values(row, &mut buf).unwrap();
            encoded_values.push(buf);
        }

        // Sparse keys use a plain (non-dictionary) binary array here.
        let pk_array: ArrayRef = Arc::new(BinaryArray::from_iter_values(
            encoded_values.iter().cloned(),
        ));

        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(pk_array),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("sparse".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(metadata_json))),
            ],
            arg_fields: vec![
                pk_field(),
                Arc::new(Field::new("encoding", DataType::Utf8, false)),
                Arc::new(Field::new("region_metadata", DataType::Utf8, false)),
            ],
            number_rows: rows.len(),
            return_field: Arc::new(Field::new(
                "decoded",
                DecodePrimaryKeyFunction::return_data_type(),
                false,
            )),
            config_options: Default::default(),
        };

        let func = DecodePrimaryKeyFunction::default();
        let result = func
            .invoke_with_args(args)
            .and_then(|v| v.to_array(rows.len()))
            .unwrap();
        let list = result.as_any().downcast_ref::<ListArray>().unwrap();

        let expected = [
            vec![
                "k0 : a".to_string(),
                "k1 : b".to_string(),
                "__tsid : 100".to_string(),
                "__table_id : 1".to_string(),
            ],
            vec![
                "k0 : c".to_string(),
                "k1 : d".to_string(),
                "__tsid : 200".to_string(),
                "__table_id : 1".to_string(),
            ],
        ];

        for (row_idx, expected_row) in expected.iter().enumerate() {
            assert_eq!(*expected_row, list_row(list, row_idx));
        }
    }
}
521}