1use std::collections::HashMap;
16
17use common_base::memory_limit::MemoryLimit;
18use serde::{Deserialize, Serialize};
19use store_api::storage::RegionId;
20use table::metadata::TableId;
21
22use crate::error::{Error, InvalidQueryContextExtensionSnafu, Result};
23
/// Query context extension key whose value is a JSON object mapping region ids to sequence
/// numbers. Sequence values may be JSON numbers or numeric strings.
pub const FLOW_INCREMENTAL_AFTER_SEQS: &str = "flow.incremental_after_seqs";
/// Query context extension key selecting the incremental mode. The only accepted value is
/// [`FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY`], compared ASCII case-insensitively.
pub const FLOW_INCREMENTAL_MODE: &str = "flow.incremental_mode";
/// Query context extension key holding a boolean (`true`/`false`, ASCII case-insensitive).
pub const FLOW_RETURN_REGION_SEQ: &str = "flow.return_region_seq";
/// Query context extension key whose value is parsed as a [`TableId`].
pub const FLOW_SINK_TABLE_ID: &str = "flow.sink_table_id";

/// Value of [`FLOW_INCREMENTAL_MODE`] that maps to `FlowIncrementalMode::MemtableOnly`.
pub const FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY: &str = "memtable_only";
30
/// Options for the query engine.
///
/// `#[serde(default)]` is applied at the container level, so any field missing from the
/// serialized form is taken from the `Default` implementation of this struct.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(default)]
pub struct QueryOptions {
    /// Parallelism setting; defaults to `0`.
    /// NOTE(review): presumably `0` means "choose automatically" — confirm with the consumer.
    pub parallelism: usize,
    /// Whether query fallback is allowed; defaults to `false`.
    /// NOTE(review): what the fallback targets is not visible here — confirm.
    pub allow_query_fallback: bool,
    /// Memory limit for the query memory pool; defaults to `MemoryLimit::default()`.
    pub memory_pool_size: MemoryLimit,
}
44
45#[allow(clippy::derivable_impls)]
46impl Default for QueryOptions {
47 fn default() -> Self {
48 Self {
49 parallelism: 0,
50 allow_query_fallback: false,
51 memory_pool_size: MemoryLimit::default(),
52 }
53 }
54}
55
/// Incremental mode requested through the `flow.incremental_mode` query context extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlowIncrementalMode {
    /// Selected by the extension value `memtable_only` (ASCII case-insensitive).
    /// NOTE(review): presumably restricts the scan to memtable data — confirm with the scanner.
    MemtableOnly,
}
60
/// Flow-related settings decoded from the string key/value extensions of a query context.
///
/// The `Default` value has no incremental settings, no sink table and
/// `return_region_seq == false`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct FlowQueryExtensions {
    /// Map from region id (as `u64`) to a sequence number, parsed from
    /// `flow.incremental_after_seqs`; `None` when the extension is absent.
    /// NOTE(review): presumably the scan should only return rows after this sequence — confirm.
    pub incremental_after_seqs: Option<HashMap<u64, u64>>,
    /// Incremental mode parsed from `flow.incremental_mode`; `None` when the extension is absent.
    pub incremental_mode: Option<FlowIncrementalMode>,
    /// Value of `flow.return_region_seq`; `false` when the extension is absent.
    pub return_region_seq: bool,
    /// Table id parsed from `flow.sink_table_id`; `None` when the extension is absent.
    pub sink_table_id: Option<TableId>,
}
72
73impl FlowQueryExtensions {
74 pub fn from_extensions(extensions: &HashMap<String, String>) -> Result<Self> {
75 let incremental_mode = extensions
76 .get(FLOW_INCREMENTAL_MODE)
77 .map(|value| match value.as_str() {
78 v if v.eq_ignore_ascii_case(FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY) => {
79 Ok(FlowIncrementalMode::MemtableOnly)
80 }
81 _ => Err(invalid_query_context_extension(format!(
82 "Invalid value for {}: {}",
83 FLOW_INCREMENTAL_MODE, value
84 ))),
85 })
86 .transpose()?;
87
88 let incremental_after_seqs = extensions
89 .get(FLOW_INCREMENTAL_AFTER_SEQS)
90 .map(|value| parse_incremental_after_seqs(value.as_str()))
91 .transpose()?;
92
93 let return_region_seq = extensions
94 .get(FLOW_RETURN_REGION_SEQ)
95 .map(|value| parse_bool(value.as_str()))
96 .transpose()?
97 .unwrap_or(false);
98
99 let sink_table_id = extensions
100 .get(FLOW_SINK_TABLE_ID)
101 .map(|value| {
102 value.parse::<TableId>().map_err(|_| {
103 invalid_query_context_extension(format!(
104 "Invalid value for {}: {}",
105 FLOW_SINK_TABLE_ID, value
106 ))
107 })
108 })
109 .transpose()?;
110
111 if matches!(incremental_mode, Some(FlowIncrementalMode::MemtableOnly)) {
112 let after_seqs = incremental_after_seqs.as_ref().ok_or_else(|| {
113 invalid_query_context_extension(format!(
114 "{} is required when {}={}.",
115 FLOW_INCREMENTAL_AFTER_SEQS,
116 FLOW_INCREMENTAL_MODE,
117 FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY
118 ))
119 })?;
120 if after_seqs.is_empty() {
121 return Err(invalid_query_context_extension(format!(
122 "{} must not be empty when {}={}.",
123 FLOW_INCREMENTAL_AFTER_SEQS,
124 FLOW_INCREMENTAL_MODE,
125 FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY
126 )));
127 }
128 }
129
130 Ok(Self {
131 incremental_after_seqs,
132 incremental_mode,
133 return_region_seq,
134 sink_table_id,
135 })
136 }
137
138 pub fn validate_for_scan(&self, source_region_id: RegionId) -> Result<bool> {
139 if self.sink_table_id.is_some() && self.sink_table_id == Some(source_region_id.table_id()) {
140 return Ok(false);
141 }
142
143 if matches!(
144 self.incremental_mode,
145 Some(FlowIncrementalMode::MemtableOnly)
146 ) {
147 let after_seqs = self.incremental_after_seqs.as_ref().ok_or_else(|| {
148 invalid_query_context_extension(format!(
149 "{} is required when {}=memtable_only.",
150 FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE
151 ))
152 })?;
153
154 if !after_seqs.contains_key(&source_region_id.as_u64()) {
155 return Err(invalid_query_context_extension(format!(
156 "Missing region {} in {} when {}=memtable_only.",
157 source_region_id, FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE
158 )));
159 }
160 }
161
162 Ok(self.incremental_after_seqs.is_some())
163 }
164
165 pub fn should_collect_region_watermark(&self) -> bool {
166 self.return_region_seq || self.incremental_after_seqs.is_some()
167 }
168}
169
170fn parse_incremental_after_seqs(value: &str) -> Result<HashMap<u64, u64>> {
171 let raw = serde_json::from_str::<HashMap<String, serde_json::Value>>(value).map_err(|e| {
172 invalid_query_context_extension(format!(
173 "Invalid JSON for {}: {} ({})",
174 FLOW_INCREMENTAL_AFTER_SEQS, value, e
175 ))
176 })?;
177
178 raw.into_iter()
179 .map(|(region_id, raw_seq)| {
180 let region_id = region_id.parse::<u64>().map_err(|_| {
181 invalid_query_context_extension(format!(
182 "Invalid region id in {}: {}",
183 FLOW_INCREMENTAL_AFTER_SEQS, region_id
184 ))
185 })?;
186
187 let seq = match raw_seq {
188 serde_json::Value::Number(num) => num.as_u64().ok_or_else(|| {
189 invalid_query_context_extension(format!(
190 "Invalid sequence value in {} for region {}: {}",
191 FLOW_INCREMENTAL_AFTER_SEQS, region_id, num
192 ))
193 })?,
194 serde_json::Value::String(s) => s.parse::<u64>().map_err(|_| {
195 invalid_query_context_extension(format!(
196 "Invalid sequence string in {} for region {}: {}",
197 FLOW_INCREMENTAL_AFTER_SEQS, region_id, s
198 ))
199 })?,
200 _ => {
201 return Err(invalid_query_context_extension(format!(
202 "Invalid sequence value type in {} for region {}",
203 FLOW_INCREMENTAL_AFTER_SEQS, region_id
204 )));
205 }
206 };
207
208 Ok((region_id, seq))
209 })
210 .collect()
211}
212
213fn parse_bool(value: &str) -> Result<bool> {
214 match value {
215 v if v.eq_ignore_ascii_case("true") => Ok(true),
216 v if v.eq_ignore_ascii_case("false") => Ok(false),
217 _ => Err(invalid_query_context_extension(format!(
218 "Invalid value for {}: {}",
219 FLOW_RETURN_REGION_SEQ, value
220 ))),
221 }
222}
223
/// Builds an [`Error`] from `InvalidQueryContextExtensionSnafu`, using `reason` as the
/// selector's `reason` field.
fn invalid_query_context_extension(reason: String) -> Error {
    InvalidQueryContextExtensionSnafu { reason }.build()
}
227
#[cfg(test)]
mod flow_extension_tests {
    use super::*;

    /// Builds an owned extension map from borrowed key/value pairs.
    fn extensions_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// Parses `pairs`, expects a failure, and returns the rendered error message.
    fn parse_error_message(pairs: &[(&str, &str)]) -> String {
        FlowQueryExtensions::from_extensions(&extensions_of(pairs))
            .unwrap_err()
            .to_string()
    }

    #[test]
    fn test_parse_flow_extensions_default() {
        let parsed = FlowQueryExtensions::from_extensions(&HashMap::new()).unwrap();

        assert_eq!(parsed.incremental_mode, None);
        assert_eq!(parsed.incremental_after_seqs, None);
        assert!(!parsed.return_region_seq);
        assert_eq!(parsed.sink_table_id, None);
    }

    #[test]
    fn test_parse_flow_extensions_memtable_only_success() {
        let exts = extensions_of(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":10,"2":20}"#),
            (FLOW_RETURN_REGION_SEQ, "true"),
            (FLOW_SINK_TABLE_ID, "1024"),
        ]);

        let parsed = FlowQueryExtensions::from_extensions(&exts).unwrap();
        assert_eq!(
            parsed.incremental_mode,
            Some(FlowIncrementalMode::MemtableOnly)
        );
        assert_eq!(
            parsed.incremental_after_seqs,
            Some(HashMap::from([(1, 10), (2, 20)]))
        );
        assert!(parsed.return_region_seq);
        assert_eq!(parsed.sink_table_id, Some(1024));
    }

    #[test]
    fn test_parse_flow_extensions_mode_requires_after_seqs() {
        let message =
            parse_error_message(&[(FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY)]);
        assert!(message.contains(FLOW_INCREMENTAL_AFTER_SEQS));
    }

    #[test]
    fn test_parse_flow_extensions_invalid_mode() {
        let message = parse_error_message(&[(FLOW_INCREMENTAL_MODE, "foo")]);
        assert!(message.contains(FLOW_INCREMENTAL_MODE));
    }

    #[test]
    fn test_parse_flow_extensions_invalid_after_seqs_json() {
        let message = parse_error_message(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, "not-json"),
        ]);
        assert!(message.contains(FLOW_INCREMENTAL_AFTER_SEQS));
    }

    #[test]
    fn test_parse_flow_extensions_after_seqs_string_values() {
        let exts = extensions_of(&[(FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":"10","2":"20"}"#)]);

        let parsed = FlowQueryExtensions::from_extensions(&exts).unwrap();
        assert_eq!(
            parsed.incremental_after_seqs,
            Some(HashMap::from([(1, 10), (2, 20)]))
        );
    }

    #[test]
    fn test_parse_flow_extensions_after_seqs_invalid_value_type() {
        let message = parse_error_message(&[(FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":true}"#)]);
        assert!(message.contains(FLOW_INCREMENTAL_AFTER_SEQS));
    }

    #[test]
    fn test_parse_flow_extensions_invalid_sink_table_id() {
        let message = parse_error_message(&[(FLOW_SINK_TABLE_ID, "x")]);
        assert!(message.contains(FLOW_SINK_TABLE_ID));
    }

    #[test]
    fn test_validate_for_scan_missing_source_region() {
        let source_region_id = RegionId::new(100, 2);
        let existing_region_id = RegionId::new(100, 1);
        // Only the sibling region is listed, so the scanned region must be reported missing.
        let after_seqs = format!(r#"{{"{}":10}}"#, existing_region_id.as_u64());
        let exts = extensions_of(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, &after_seqs),
        ]);

        let parsed = FlowQueryExtensions::from_extensions(&exts).unwrap();
        let message = parsed
            .validate_for_scan(source_region_id)
            .unwrap_err()
            .to_string();
        assert!(message.contains("Missing region"));
    }

    #[test]
    fn test_validate_for_scan_sink_table_excluded() {
        let source_region_id = RegionId::new(1024, 1);
        let after_seqs = format!(r#"{{"{}":10}}"#, source_region_id.as_u64());
        let exts = extensions_of(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, &after_seqs),
            (FLOW_SINK_TABLE_ID, "1024"),
        ]);

        let parsed = FlowQueryExtensions::from_extensions(&exts).unwrap();
        let apply_incremental = parsed.validate_for_scan(source_region_id).unwrap();
        assert!(!apply_incremental);
    }

    #[test]
    fn test_should_collect_region_watermark_defaults_false() {
        assert!(!FlowQueryExtensions::default().should_collect_region_watermark());
    }

    #[test]
    fn test_should_collect_region_watermark_true_for_return_region_seq() {
        let parsed = FlowQueryExtensions {
            return_region_seq: true,
            ..Default::default()
        };
        assert!(parsed.should_collect_region_watermark());
    }

    #[test]
    fn test_should_collect_region_watermark_true_for_incremental_query() {
        let parsed = FlowQueryExtensions {
            incremental_after_seqs: Some(HashMap::from([(1, 10)])),
            ..Default::default()
        };
        assert!(parsed.should_collect_region_watermark());
    }
}