common_datasource/
compression.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16use std::io;
17use std::str::FromStr;
18
19use async_compression::tokio::bufread::{BzDecoder, GzipDecoder, XzDecoder, ZstdDecoder};
20use async_compression::tokio::write;
21use bytes::Bytes;
22use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
23use futures::Stream;
24use serde::{Deserialize, Serialize};
25use strum::EnumIter;
26use tokio::io::{AsyncRead, AsyncWriteExt, BufReader};
27use tokio_util::io::{ReaderStream, StreamReader};
28
29use crate::error::{self, Error, Result};
30
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumIter, Serialize, Deserialize)]
32#[serde(rename_all = "lowercase")]
33pub enum CompressionType {
34    /// Gzip-ed file
35    Gzip,
36    /// Bzip2-ed file
37    Bzip2,
38    /// Xz-ed file (liblzma)
39    Xz,
40    /// Zstd-ed file,
41    Zstd,
42    /// Uncompressed file
43    Uncompressed,
44}
45
46impl FromStr for CompressionType {
47    type Err = Error;
48
49    fn from_str(s: &str) -> Result<Self> {
50        let s = s.to_uppercase();
51        match s.as_str() {
52            "GZIP" | "GZ" => Ok(Self::Gzip),
53            "BZIP2" | "BZ2" => Ok(Self::Bzip2),
54            "XZ" => Ok(Self::Xz),
55            "ZST" | "ZSTD" => Ok(Self::Zstd),
56            "" => Ok(Self::Uncompressed),
57            _ => error::UnsupportedCompressionTypeSnafu {
58                compression_type: s,
59            }
60            .fail(),
61        }
62    }
63}
64
65impl Display for CompressionType {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.write_str(match self {
68            Self::Gzip => "GZIP",
69            Self::Bzip2 => "BZIP2",
70            Self::Xz => "XZ",
71            Self::Zstd => "ZSTD",
72            Self::Uncompressed => "",
73        })
74    }
75}
76
77impl CompressionType {
78    pub const fn is_compressed(&self) -> bool {
79        !matches!(self, &Self::Uncompressed)
80    }
81
82    pub const fn file_extension(&self) -> &'static str {
83        match self {
84            Self::Gzip => "gz",
85            Self::Bzip2 => "bz2",
86            Self::Xz => "xz",
87            Self::Zstd => "zst",
88            Self::Uncompressed => "",
89        }
90    }
91}
92
93macro_rules! impl_compression_type {
94    ($(($enum_item:ident, $prefix:ident)),*) => {
95        paste::item! {
96            use bytes::{Buf, BufMut, BytesMut};
97
98            impl CompressionType {
99                pub async fn encode<B: Buf>(&self, mut content: B) -> io::Result<Vec<u8>> {
100                    match self {
101                        $(
102                            CompressionType::$enum_item => {
103                                let mut buffer = Vec::with_capacity(content.remaining());
104                                let mut encoder = write::[<$prefix Encoder>]::new(&mut buffer);
105                                encoder.write_all_buf(&mut content).await?;
106                                encoder.shutdown().await?;
107                                Ok(buffer)
108                            }
109                        )*
110                        CompressionType::Uncompressed => {
111                            let mut bs = BytesMut::with_capacity(content.remaining());
112                            bs.put(content);
113                            Ok(bs.to_vec())
114                        },
115                    }
116                }
117
118                pub async fn decode<B: Buf>(&self, mut content: B) -> io::Result<Vec<u8>> {
119                    match self {
120                        $(
121                            CompressionType::$enum_item => {
122                                let mut buffer = Vec::with_capacity(content.remaining() * 2);
123                                let mut encoder = write::[<$prefix Decoder>]::new(&mut buffer);
124                                encoder.write_all_buf(&mut content).await?;
125                                encoder.shutdown().await?;
126                                Ok(buffer)
127                            }
128                        )*
129                        CompressionType::Uncompressed => {
130                            let mut bs = BytesMut::with_capacity(content.remaining());
131                            bs.put(content);
132                            Ok(bs.to_vec())
133                        },
134                    }
135                }
136
137                pub fn convert_async_read<T: AsyncRead + Unpin + Send + 'static>(
138                    &self,
139                    s: T,
140                ) -> Box<dyn AsyncRead + Unpin + Send> {
141                    match self {
142                        $(CompressionType::$enum_item => Box::new([<$prefix Decoder>]::new(BufReader::new(s))),)*
143                        CompressionType::Uncompressed => Box::new(s),
144                    }
145                }
146
147                pub fn convert_stream<T: Stream<Item = io::Result<Bytes>> + Unpin + Send + 'static>(
148                    &self,
149                    s: T,
150                ) -> Box<dyn Stream<Item = io::Result<Bytes>> + Send + Unpin> {
151                    match self {
152                        $(CompressionType::$enum_item => Box::new(ReaderStream::new([<$prefix Decoder>]::new(StreamReader::new(s)))),)*
153                        CompressionType::Uncompressed => Box::new(s),
154                    }
155                }
156            }
157
158            #[cfg(test)]
159            mod tests {
160                use super::CompressionType;
161
162                $(
163                #[tokio::test]
164                async fn [<test_ $enum_item:lower _compression>]() {
165                    let string = "foo_bar".as_bytes();
166                    let compress = CompressionType::$enum_item
167                        .encode(string)
168                        .await
169                        .unwrap();
170                    let decompress = CompressionType::$enum_item
171                        .decode(compress.as_slice())
172                        .await
173                        .unwrap();
174                    assert_eq!(decompress, string);
175                })*
176
177                #[tokio::test]
178                async fn test_uncompression() {
179                    let string = "foo_bar".as_bytes();
180                    let compress = CompressionType::Uncompressed
181                        .encode(string)
182                        .await
183                        .unwrap();
184                    let decompress = CompressionType::Uncompressed
185                        .decode(compress.as_slice())
186                        .await
187                        .unwrap();
188                    assert_eq!(decompress, string);
189                }
190            }
191        }
192    };
193}
194
195impl_compression_type!((Gzip, Gzip), (Bzip2, Bz), (Xz, Xz), (Zstd, Zstd));
196
197impl From<CompressionType> for FileCompressionType {
198    fn from(t: CompressionType) -> Self {
199        match t {
200            CompressionType::Gzip => FileCompressionType::GZIP,
201            CompressionType::Bzip2 => FileCompressionType::BZIP2,
202            CompressionType::Xz => FileCompressionType::XZ,
203            CompressionType::Zstd => FileCompressionType::ZSTD,
204            CompressionType::Uncompressed => FileCompressionType::UNCOMPRESSED,
205        }
206    }
207}