common_datasource/
compression.rs1use std::fmt::Display;
16use std::io;
17use std::str::FromStr;
18
19use async_compression::tokio::bufread::{BzDecoder, GzipDecoder, XzDecoder, ZstdDecoder};
20use async_compression::tokio::write;
21use bytes::Bytes;
22use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
23use futures::Stream;
24use serde::{Deserialize, Serialize};
25use strum::EnumIter;
26use tokio::io::{AsyncRead, AsyncWriteExt, BufReader};
27use tokio_util::io::{ReaderStream, StreamReader};
28
29use crate::error::{self, Error, Result};
30
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumIter, Serialize, Deserialize)]
32#[serde(rename_all = "lowercase")]
33pub enum CompressionType {
34 Gzip,
36 Bzip2,
38 Xz,
40 Zstd,
42 Uncompressed,
44}
45
46impl FromStr for CompressionType {
47 type Err = Error;
48
49 fn from_str(s: &str) -> Result<Self> {
50 let s = s.to_uppercase();
51 match s.as_str() {
52 "GZIP" | "GZ" => Ok(Self::Gzip),
53 "BZIP2" | "BZ2" => Ok(Self::Bzip2),
54 "XZ" => Ok(Self::Xz),
55 "ZST" | "ZSTD" => Ok(Self::Zstd),
56 "" => Ok(Self::Uncompressed),
57 _ => error::UnsupportedCompressionTypeSnafu {
58 compression_type: s,
59 }
60 .fail(),
61 }
62 }
63}
64
65impl Display for CompressionType {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.write_str(match self {
68 Self::Gzip => "GZIP",
69 Self::Bzip2 => "BZIP2",
70 Self::Xz => "XZ",
71 Self::Zstd => "ZSTD",
72 Self::Uncompressed => "",
73 })
74 }
75}
76
77impl CompressionType {
78 pub const fn is_compressed(&self) -> bool {
79 !matches!(self, &Self::Uncompressed)
80 }
81
82 pub const fn file_extension(&self) -> &'static str {
83 match self {
84 Self::Gzip => "gz",
85 Self::Bzip2 => "bz2",
86 Self::Xz => "xz",
87 Self::Zstd => "zst",
88 Self::Uncompressed => "",
89 }
90 }
91}
92
93macro_rules! impl_compression_type {
94 ($(($enum_item:ident, $prefix:ident)),*) => {
95 paste::item! {
96 use bytes::{Buf, BufMut, BytesMut};
97
98 impl CompressionType {
99 pub async fn encode<B: Buf>(&self, mut content: B) -> io::Result<Vec<u8>> {
100 match self {
101 $(
102 CompressionType::$enum_item => {
103 let mut buffer = Vec::with_capacity(content.remaining());
104 let mut encoder = write::[<$prefix Encoder>]::new(&mut buffer);
105 encoder.write_all_buf(&mut content).await?;
106 encoder.shutdown().await?;
107 Ok(buffer)
108 }
109 )*
110 CompressionType::Uncompressed => {
111 let mut bs = BytesMut::with_capacity(content.remaining());
112 bs.put(content);
113 Ok(bs.to_vec())
114 },
115 }
116 }
117
118 pub async fn decode<B: Buf>(&self, mut content: B) -> io::Result<Vec<u8>> {
119 match self {
120 $(
121 CompressionType::$enum_item => {
122 let mut buffer = Vec::with_capacity(content.remaining() * 2);
123 let mut encoder = write::[<$prefix Decoder>]::new(&mut buffer);
124 encoder.write_all_buf(&mut content).await?;
125 encoder.shutdown().await?;
126 Ok(buffer)
127 }
128 )*
129 CompressionType::Uncompressed => {
130 let mut bs = BytesMut::with_capacity(content.remaining());
131 bs.put(content);
132 Ok(bs.to_vec())
133 },
134 }
135 }
136
137 pub fn convert_async_read<T: AsyncRead + Unpin + Send + 'static>(
138 &self,
139 s: T,
140 ) -> Box<dyn AsyncRead + Unpin + Send> {
141 match self {
142 $(CompressionType::$enum_item => Box::new([<$prefix Decoder>]::new(BufReader::new(s))),)*
143 CompressionType::Uncompressed => Box::new(s),
144 }
145 }
146
147 pub fn convert_stream<T: Stream<Item = io::Result<Bytes>> + Unpin + Send + 'static>(
148 &self,
149 s: T,
150 ) -> Box<dyn Stream<Item = io::Result<Bytes>> + Send + Unpin> {
151 match self {
152 $(CompressionType::$enum_item => Box::new(ReaderStream::new([<$prefix Decoder>]::new(StreamReader::new(s)))),)*
153 CompressionType::Uncompressed => Box::new(s),
154 }
155 }
156 }
157
158 #[cfg(test)]
159 mod tests {
160 use super::CompressionType;
161
162 $(
163 #[tokio::test]
164 async fn [<test_ $enum_item:lower _compression>]() {
165 let string = "foo_bar".as_bytes();
166 let compress = CompressionType::$enum_item
167 .encode(string)
168 .await
169 .unwrap();
170 let decompress = CompressionType::$enum_item
171 .decode(compress.as_slice())
172 .await
173 .unwrap();
174 assert_eq!(decompress, string);
175 })*
176
177 #[tokio::test]
178 async fn test_uncompression() {
179 let string = "foo_bar".as_bytes();
180 let compress = CompressionType::Uncompressed
181 .encode(string)
182 .await
183 .unwrap();
184 let decompress = CompressionType::Uncompressed
185 .decode(compress.as_slice())
186 .await
187 .unwrap();
188 assert_eq!(decompress, string);
189 }
190 }
191 }
192 };
193}
194
195impl_compression_type!((Gzip, Gzip), (Bzip2, Bz), (Xz, Xz), (Zstd, Zstd));
196
197impl From<CompressionType> for FileCompressionType {
198 fn from(t: CompressionType) -> Self {
199 match t {
200 CompressionType::Gzip => FileCompressionType::GZIP,
201 CompressionType::Bzip2 => FileCompressionType::BZIP2,
202 CompressionType::Xz => FileCompressionType::XZ,
203 CompressionType::Zstd => FileCompressionType::ZSTD,
204 CompressionType::Uncompressed => FileCompressionType::UNCOMPRESSED,
205 }
206 }
207}