tests_fuzz/utils/
sql_dump_writer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fs::{OpenOptions, create_dir_all};
17use std::io::Write;
18use std::path::PathBuf;
19
20use snafu::ResultExt;
21
22use crate::error::{self, Result};
23use crate::utils::get_gt_fuzz_dump_buffer_max_bytes;
24
/// Session writer for table-scoped SQL trace files.
///
/// Entries are staged in memory per logical table and appended to
/// `<run_dir>/<sanitized table name>.trace.sql` when the buffer is flushed.
#[derive(Debug)]
pub struct SqlDumpSession {
    /// Session directory path.
    pub run_dir: PathBuf,
    /// Max in-memory buffer size before auto flush.
    pub max_buffer_bytes: usize,
    /// Sum of the byte lengths of all entries currently staged in
    /// `entries_by_table`; reset to zero after a flush.
    buffered_bytes: usize,
    /// Staged, not-yet-written entries keyed by logical table name.
    entries_by_table: HashMap<String, Vec<String>>,
}
35
36impl SqlDumpSession {
37    /// Creates SQL dump session with default buffer limit.
38    pub fn new(run_dir: PathBuf) -> Result<Self> {
39        Self::new_with_buffer_limit(run_dir, get_gt_fuzz_dump_buffer_max_bytes())
40    }
41
42    /// Creates SQL dump session with custom buffer limit.
43    pub fn new_with_buffer_limit(run_dir: PathBuf, max_buffer_bytes: usize) -> Result<Self> {
44        create_dir_all(&run_dir).context(error::CreateFileSnafu {
45            path: run_dir.to_string_lossy().to_string(),
46        })?;
47
48        Ok(Self {
49            run_dir,
50            max_buffer_bytes,
51            buffered_bytes: 0,
52            entries_by_table: HashMap::new(),
53        })
54    }
55
56    /// Appends one SQL statement for a logical table.
57    pub fn append_sql(&mut self, table: &str, sql: &str, comment: Option<&str>) -> Result<()> {
58        let entry = format_sql_entry(sql, comment);
59        self.push_entry(table, entry)?;
60        Ok(())
61    }
62
63    /// Broadcasts one comment event to all table trace files.
64    pub fn broadcast_event<I, T>(&mut self, tables: I, event: &str, sql: &str) -> Result<()>
65    where
66        I: IntoIterator<Item = T>,
67        T: AsRef<str>,
68    {
69        let entry = format_sql_entry(sql, Some(event));
70        for table in tables {
71            self.push_entry(table.as_ref(), entry.clone())?;
72        }
73        Ok(())
74    }
75
76    /// Flushes all staged SQL traces to table-scoped files.
77    pub fn flush_all(&mut self) -> Result<()> {
78        self.flush_buffered_entries()
79    }
80
81    fn push_entry(&mut self, table: &str, entry: String) -> Result<()> {
82        self.buffered_bytes += entry.len();
83        self.entries_by_table
84            .entry(table.to_string())
85            .or_default()
86            .push(entry);
87
88        if self.buffered_bytes >= self.max_buffer_bytes {
89            self.flush_buffered_entries()?;
90        }
91        Ok(())
92    }
93
94    fn flush_buffered_entries(&mut self) -> Result<()> {
95        if self.entries_by_table.is_empty() {
96            return Ok(());
97        }
98
99        for (table, entries) in &self.entries_by_table {
100            let path = self
101                .run_dir
102                .join(format!("{}.trace.sql", sanitize_file_name(table)));
103            let mut file = OpenOptions::new()
104                .create(true)
105                .append(true)
106                .open(&path)
107                .context(error::CreateFileSnafu {
108                    path: path.to_string_lossy().to_string(),
109                })?;
110
111            for entry in entries {
112                file.write_all(entry.as_bytes())
113                    .context(error::WriteFileSnafu {
114                        path: path.to_string_lossy().to_string(),
115                    })?;
116                file.write_all(b"\n").context(error::WriteFileSnafu {
117                    path: path.to_string_lossy().to_string(),
118                })?;
119            }
120        }
121
122        self.entries_by_table.clear();
123        self.buffered_bytes = 0;
124        Ok(())
125    }
126}
127
128fn format_sql_entry(sql: &str, comment: Option<&str>) -> String {
129    let normalized_sql = normalize_sql(sql);
130    if let Some(comment) = comment {
131        format!("{}\n{normalized_sql}", format_comment(comment))
132    } else {
133        normalized_sql
134    }
135}
136
/// Turns `comment` into SQL line comments by prefixing each of its lines
/// with `-- `; the resulting lines are separated by `\n`.
fn format_comment(comment: &str) -> String {
    let mut rendered = String::with_capacity(comment.len() + 3);
    for (idx, line) in comment.lines().enumerate() {
        // Separator goes between lines only, never after the last one.
        if idx > 0 {
            rendered.push('\n');
        }
        rendered.push_str("-- ");
        rendered.push_str(line);
    }
    rendered
}
144
/// Strips trailing whitespace from `sql` and guarantees the result ends with
/// a semicolon, appending one only when it is missing.
fn normalize_sql(sql: &str) -> String {
    let body = sql.trim_end();
    let mut statement = String::with_capacity(body.len() + 1);
    statement.push_str(body);
    if !statement.ends_with(';') {
        statement.push(';');
    }
    statement
}
153
/// Maps `raw` to a file-name-safe string: ASCII letters, digits, `_` and `-`
/// are kept, every other character is replaced by a single `_`.
fn sanitize_file_name(raw: &str) -> String {
    let mut safe = String::with_capacity(raw.len());
    for ch in raw.chars() {
        let keep = matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-');
        safe.push(if keep { ch } else { '_' });
    }
    safe
}
165
#[cfg(test)]
mod tests {
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::time::{SystemTime, UNIX_EPOCH};

    use super::SqlDumpSession;

    /// Scratch directory for one test.
    ///
    /// The path combines the process id with a nanosecond timestamp so
    /// concurrent test runs do not share a directory, and the directory is
    /// removed again on drop so tests do not leave files in the temp dir.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Picks a fresh path under the system temp dir; the directory itself
        /// is created by `SqlDumpSession::new_with_buffer_limit`.
        fn new(prefix: &str) -> Self {
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            let name = format!("{prefix}-{}-{nanos}", std::process::id());
            Self(std::env::temp_dir().join(name))
        }

        /// Root of the scratch directory.
        fn path(&self) -> &Path {
            &self.0
        }

        /// Reads the trace file written for `table`.
        fn read_trace(&self, table: &str) -> String {
            fs::read_to_string(self.trace_file(table)).unwrap()
        }

        /// Path of the trace file expected for `table`.
        fn trace_file(&self, table: &str) -> PathBuf {
            self.0.join(format!("{table}.trace.sql"))
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure here must not mask the test result.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Opens a session that writes into `scratch`.
    fn new_session(scratch: &ScratchDir, max_buffer_bytes: usize) -> SqlDumpSession {
        SqlDumpSession::new_with_buffer_limit(scratch.path().to_path_buf(), max_buffer_bytes)
            .unwrap()
    }

    #[test]
    fn test_append_sql_writes_table_trace_file() {
        let scratch = ScratchDir::new("tests-fuzz-sql-dump");
        let mut session = new_session(&scratch, 1024);

        session
            .append_sql(
                "metric-a",
                "INSERT INTO t VALUES(1)",
                Some("kind=insert elapsed_ms=10"),
            )
            .unwrap();
        session.flush_all().unwrap();

        let content = scratch.read_trace("metric-a");
        assert!(content.contains("-- kind=insert elapsed_ms=10"));
        assert!(content.contains("INSERT INTO t VALUES(1);"));
    }

    #[test]
    fn test_broadcast_event_writes_to_all_tables() {
        let scratch = ScratchDir::new("tests-fuzz-sql-broadcast");
        let mut session = new_session(&scratch, 1024);

        session
            .broadcast_event(
                ["metric-a", "metric-b"],
                "repartition action_idx=3",
                "ALTER TABLE t REPARTITION",
            )
            .unwrap();
        session.flush_all().unwrap();

        for table in ["metric-a", "metric-b"] {
            let content = scratch.read_trace(table);
            assert!(content.contains("-- repartition action_idx=3"));
            assert!(content.contains("ALTER TABLE t REPARTITION;"));
        }
    }

    #[test]
    fn test_multiline_comment_is_prefixed_per_line() {
        let scratch = ScratchDir::new("tests-fuzz-sql-dump-comment");
        let mut session = new_session(&scratch, 1024);

        session
            .append_sql(
                "metric-a",
                "INSERT INTO t VALUES(1)",
                Some("kind=insert\nstarted_at_ms=1 elapsed_ms=2"),
            )
            .unwrap();
        session.flush_all().unwrap();

        let content = scratch.read_trace("metric-a");
        assert!(content.contains("-- kind=insert\n-- started_at_ms=1 elapsed_ms=2"));
    }

    #[test]
    fn test_auto_flush_on_buffer_limit() {
        let scratch = ScratchDir::new("tests-fuzz-sql-dump-limit");
        // A one-byte limit forces a flush on the very first entry.
        let mut session = new_session(&scratch, 1);

        session
            .append_sql("metric-a", "INSERT INTO t VALUES(1)", None)
            .unwrap();

        assert!(scratch.trace_file("metric-a").exists());
        assert_eq!(session.buffered_bytes, 0);
    }
}