common_datasource/
compressed_writer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::io;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use async_compression::tokio::write::{BzEncoder, GzipEncoder, XzEncoder, ZstdEncoder};
20use snafu::ResultExt;
21use tokio::io::{AsyncWrite, AsyncWriteExt};
22
23use crate::compression::CompressionType;
24use crate::error::{self, Result};
25
/// A compressed writer that wraps an underlying async writer with compression.
///
/// This writer supports multiple compression formats including GZIP, BZIP2, XZ, and ZSTD,
/// and can also pass data through untouched (`CompressionType::Uncompressed`).
/// The wrapped writer must be `AsyncWrite + Unpin + Send + 'static`; it is stored behind
/// a `Box<dyn ...>` so the concrete encoder type does not leak into this struct's type.
pub struct CompressedWriter {
    /// The type-erased sink all writes are delegated to: the underlying writer wrapped in
    /// the selected encoder, or the bare underlying writer when no compression is used.
    inner: Box<dyn AsyncWrite + Unpin + Send>,
    /// The compression type this writer was constructed with.
    compression_type: CompressionType,
}
34
35impl CompressedWriter {
36    /// Creates a new compressed writer with the specified compression type.
37    ///
38    /// # Arguments
39    ///
40    /// * `writer` - The underlying writer to wrap with compression
41    /// * `compression_type` - The type of compression to apply
42    pub fn new(
43        writer: impl AsyncWrite + Unpin + Send + 'static,
44        compression_type: CompressionType,
45    ) -> Self {
46        let inner: Box<dyn AsyncWrite + Unpin + Send> = match compression_type {
47            CompressionType::Gzip => Box::new(GzipEncoder::new(writer)),
48            CompressionType::Bzip2 => Box::new(BzEncoder::new(writer)),
49            CompressionType::Xz => Box::new(XzEncoder::new(writer)),
50            CompressionType::Zstd => Box::new(ZstdEncoder::new(writer)),
51            CompressionType::Uncompressed => Box::new(writer),
52        };
53
54        Self {
55            inner,
56            compression_type,
57        }
58    }
59
60    /// Returns the compression type used by this writer.
61    pub fn compression_type(&self) -> CompressionType {
62        self.compression_type
63    }
64
65    /// Flush the writer and shutdown compression
66    pub async fn shutdown(mut self) -> Result<()> {
67        self.inner
68            .shutdown()
69            .await
70            .context(error::AsyncWriteSnafu)?;
71        Ok(())
72    }
73}
74
75impl AsyncWrite for CompressedWriter {
76    fn poll_write(
77        mut self: Pin<&mut Self>,
78        cx: &mut Context<'_>,
79        buf: &[u8],
80    ) -> Poll<io::Result<usize>> {
81        Pin::new(&mut self.inner).poll_write(cx, buf)
82    }
83
84    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
85        Pin::new(&mut self.inner).poll_flush(cx)
86    }
87
88    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
89        Pin::new(&mut self.inner).poll_shutdown(cx)
90    }
91}
92
/// A trait for converting async writers into compressed writers.
///
/// This trait is automatically implemented (via the blanket impl below) for every type
/// that is `AsyncWrite + Unpin + Send + 'static`.
pub trait IntoCompressedWriter {
    /// Converts this writer into a [`CompressedWriter`] with the specified compression type.
    ///
    /// This is a convenience wrapper around [`CompressedWriter::new`].
    ///
    /// # Arguments
    ///
    /// * `self` - The underlying writer to wrap with compression
    /// * `compression_type` - The type of compression to apply
    fn into_compressed_writer(self, compression_type: CompressionType) -> CompressedWriter
    where
        Self: AsyncWrite + Unpin + Send + 'static + Sized,
    {
        CompressedWriter::new(self, compression_type)
    }
}

// Blanket impl: the default method body is all that is needed, so the impl is empty.
impl<W: AsyncWrite + Unpin + Send + 'static> IntoCompressedWriter for W {}
112
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    use super::*;

    /// Writes `payload` through a [`CompressedWriter`] of the given type and returns the
    /// bytes that arrive on the reading half of an in-memory duplex pipe.
    ///
    /// All writing finishes before any reading starts, so the encoded output has to fit
    /// in the 1024-byte duplex buffer; the short payloads used below are expected to.
    async fn write_through(compression_type: CompressionType, payload: &[u8]) -> Vec<u8> {
        let (duplex_writer, mut duplex_reader) = duplex(1024);
        let mut writer = duplex_writer.into_compressed_writer(compression_type);

        writer.write_all(payload).await.unwrap();
        // Consumes the writer, which also drops the write half so the reader sees EOF.
        writer.shutdown().await.unwrap();

        let mut buffer = Vec::new();
        duplex_reader.read_to_end(&mut buffer).await.unwrap();
        buffer
    }

    #[tokio::test]
    async fn test_compressed_writer_gzip() {
        let original = b"test data for gzip compression";
        let buffer = write_through(CompressionType::Gzip, original).await;

        // The compressed data should be different from the original
        assert_ne!(buffer, original);
        assert!(!buffer.is_empty());
        // A gzip stream starts with the magic bytes 1f 8b.
        assert!(buffer.starts_with(&[0x1f, 0x8b]));
    }

    #[tokio::test]
    async fn test_compressed_writer_bzip2() {
        let original = b"test data for bzip2 compression";
        let buffer = write_through(CompressionType::Bzip2, original).await;

        // The compressed data should be different from the original
        assert_ne!(buffer, original);
        assert!(!buffer.is_empty());
        // A bzip2 stream starts with the ASCII signature "BZh".
        assert!(buffer.starts_with(b"BZh"));
    }

    #[tokio::test]
    async fn test_compressed_writer_xz() {
        let original = b"test data for xz compression";
        let buffer = write_through(CompressionType::Xz, original).await;

        // The compressed data should be different from the original
        assert_ne!(buffer, original);
        assert!(!buffer.is_empty());
        // An xz stream starts with the magic bytes fd '7' 'z' 'X' 'Z' 00.
        assert!(buffer.starts_with(&[0xfd, 0x37, 0x7a, 0x58, 0x5a, 0x00]));
    }

    #[tokio::test]
    async fn test_compressed_writer_zstd() {
        let original = b"test data for zstd compression";
        let buffer = write_through(CompressionType::Zstd, original).await;

        // The compressed data should be different from the original
        assert_ne!(buffer, original);
        assert!(!buffer.is_empty());
        // A zstd frame starts with the magic number 0xFD2FB528 in little-endian order.
        assert!(buffer.starts_with(&[0x28, 0xb5, 0x2f, 0xfd]));
    }

    #[tokio::test]
    async fn test_compressed_writer_uncompressed() {
        let original = b"test data for uncompressed";
        let buffer = write_through(CompressionType::Uncompressed, original).await;

        // Uncompressed data should be the same as the original
        assert_eq!(buffer, original);
    }
}