common_datasource/
compressed_writer.rs1use std::io;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use async_compression::tokio::write::{BzEncoder, GzipEncoder, XzEncoder, ZstdEncoder};
20use snafu::ResultExt;
21use tokio::io::{AsyncWrite, AsyncWriteExt};
22
23use crate::compression::CompressionType;
24use crate::error::{self, Result};
25
26pub struct CompressedWriter {
31 inner: Box<dyn AsyncWrite + Unpin + Send>,
32 compression_type: CompressionType,
33}
34
35impl CompressedWriter {
36 pub fn new(
43 writer: impl AsyncWrite + Unpin + Send + 'static,
44 compression_type: CompressionType,
45 ) -> Self {
46 let inner: Box<dyn AsyncWrite + Unpin + Send> = match compression_type {
47 CompressionType::Gzip => Box::new(GzipEncoder::new(writer)),
48 CompressionType::Bzip2 => Box::new(BzEncoder::new(writer)),
49 CompressionType::Xz => Box::new(XzEncoder::new(writer)),
50 CompressionType::Zstd => Box::new(ZstdEncoder::new(writer)),
51 CompressionType::Uncompressed => Box::new(writer),
52 };
53
54 Self {
55 inner,
56 compression_type,
57 }
58 }
59
60 pub fn compression_type(&self) -> CompressionType {
62 self.compression_type
63 }
64
65 pub async fn shutdown(mut self) -> Result<()> {
67 self.inner
68 .shutdown()
69 .await
70 .context(error::AsyncWriteSnafu)?;
71 Ok(())
72 }
73}
74
75impl AsyncWrite for CompressedWriter {
76 fn poll_write(
77 mut self: Pin<&mut Self>,
78 cx: &mut Context<'_>,
79 buf: &[u8],
80 ) -> Poll<io::Result<usize>> {
81 Pin::new(&mut self.inner).poll_write(cx, buf)
82 }
83
84 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
85 Pin::new(&mut self.inner).poll_flush(cx)
86 }
87
88 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
89 Pin::new(&mut self.inner).poll_shutdown(cx)
90 }
91}
92
93pub trait IntoCompressedWriter {
97 fn into_compressed_writer(self, compression_type: CompressionType) -> CompressedWriter
104 where
105 Self: AsyncWrite + Unpin + Send + 'static + Sized,
106 {
107 CompressedWriter::new(self, compression_type)
108 }
109}
110
111impl<W: AsyncWrite + Unpin + Send + 'static> IntoCompressedWriter for W {}
112
#[cfg(test)]
mod tests {
    use async_compression::tokio::bufread::{BzDecoder, GzipDecoder, XzDecoder, ZstdDecoder};
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex};

    use super::*;

    /// Capacity of the in-memory pipe that stands in for the real sink.
    const PIPE_CAPACITY: usize = 1024;

    /// Writes `original` through a [`CompressedWriter`] and returns every byte
    /// that reached the underlying pipe.
    ///
    /// Writer and reader run concurrently: the pipe buffers at most
    /// `PIPE_CAPACITY` bytes, so writing everything before reading would stall
    /// as soon as the encoded output outgrew it.
    async fn compress(compression_type: CompressionType, original: &[u8]) -> Vec<u8> {
        let (duplex_writer, mut duplex_reader) = duplex(PIPE_CAPACITY);

        let write = async move {
            let mut writer = duplex_writer.into_compressed_writer(compression_type);
            writer.write_all(original).await.unwrap();
            // Closes the pipe's write side so `read_to_end` below reaches EOF.
            writer.shutdown().await.unwrap();
        };
        let read = async move {
            let mut buffer = Vec::new();
            duplex_reader.read_to_end(&mut buffer).await.unwrap();
            buffer
        };

        let ((), buffer) = tokio::join!(write, read);
        buffer
    }

    /// Decodes `compressed` with the decoder matching `compression_type`, so
    /// the tests can check the encoder's output is a valid, lossless stream.
    async fn decompress(compression_type: CompressionType, compressed: &[u8]) -> Vec<u8> {
        let mut decoder: Box<dyn AsyncRead + Unpin + '_> = match compression_type {
            CompressionType::Gzip => Box::new(GzipDecoder::new(compressed)),
            CompressionType::Bzip2 => Box::new(BzDecoder::new(compressed)),
            CompressionType::Xz => Box::new(XzDecoder::new(compressed)),
            CompressionType::Zstd => Box::new(ZstdDecoder::new(compressed)),
            CompressionType::Uncompressed => Box::new(compressed),
        };
        let mut decoded = Vec::new();
        decoder.read_to_end(&mut decoded).await.unwrap();
        decoded
    }

    #[tokio::test]
    async fn test_compressed_writer_gzip() {
        let original = b"test data for gzip compression";
        let compressed = compress(CompressionType::Gzip, original).await;

        assert!(!compressed.is_empty());
        assert_ne!(compressed, original);
        assert_eq!(decompress(CompressionType::Gzip, &compressed).await, original);
    }

    #[tokio::test]
    async fn test_compressed_writer_bzip2() {
        let original = b"test data for bzip2 compression";
        let compressed = compress(CompressionType::Bzip2, original).await;

        assert!(!compressed.is_empty());
        assert_ne!(compressed, original);
        assert_eq!(decompress(CompressionType::Bzip2, &compressed).await, original);
    }

    #[tokio::test]
    async fn test_compressed_writer_xz() {
        let original = b"test data for xz compression";
        let compressed = compress(CompressionType::Xz, original).await;

        assert!(!compressed.is_empty());
        assert_ne!(compressed, original);
        assert_eq!(decompress(CompressionType::Xz, &compressed).await, original);
    }

    #[tokio::test]
    async fn test_compressed_writer_zstd() {
        let original = b"test data for zstd compression";
        let compressed = compress(CompressionType::Zstd, original).await;

        assert!(!compressed.is_empty());
        assert_ne!(compressed, original);
        assert_eq!(decompress(CompressionType::Zstd, &compressed).await, original);
    }

    #[tokio::test]
    async fn test_compressed_writer_uncompressed() {
        let original = b"test data for uncompressed";
        let compressed = compress(CompressionType::Uncompressed, original).await;

        assert_eq!(compressed, original);
    }

    #[tokio::test]
    async fn test_compressed_writer_round_trip_larger_than_pipe() {
        // 4x the pipe capacity: for `Uncompressed` the output alone overflows
        // the pipe, which only works because `compress` drains concurrently.
        let original: Vec<u8> = (0..=255u8).cycle().take(4 * PIPE_CAPACITY).collect();

        for compression_type in [
            CompressionType::Gzip,
            CompressionType::Bzip2,
            CompressionType::Xz,
            CompressionType::Zstd,
            CompressionType::Uncompressed,
        ] {
            let compressed = compress(compression_type, &original).await;
            assert_eq!(decompress(compression_type, &compressed).await, original);
        }
    }
}