common_base/
range_read.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16use std::io;
17use std::ops::Range;
18use std::path::Path;
19use std::pin::Pin;
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use async_trait::async_trait;
25use bytes::{BufMut, Bytes};
26use futures::AsyncRead;
27use pin_project::pin_project;
28use tokio::io::{AsyncReadExt as _, AsyncSeekExt as _};
29use tokio::sync::Mutex;
30
/// `Metadata` contains the metadata of a source.
///
/// It is a plain value type, so it derives the usual comparison, cloning and
/// debugging traits for use in assertions and log output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Metadata {
    /// The length of the source in bytes.
    pub content_length: u64,
}
36
37/// `SizeAwareRangeReader` is a `RangeReader` that supports setting a file size hint.
38pub trait SizeAwareRangeReader: RangeReader {
39    /// Sets the file size hint for the reader.
40    ///
41    /// It's used to optimize the reading process by reducing the number of remote requests.
42    fn with_file_size_hint(&mut self, file_size_hint: u64);
43}
44
45/// `RangeReader` reads a range of bytes from a source.
46#[async_trait]
47pub trait RangeReader: Sync + Send + Unpin {
48    /// Returns the metadata of the source.
49    async fn metadata(&self) -> io::Result<Metadata>;
50
51    /// Reads the bytes in the given range.
52    async fn read(&self, range: Range<u64>) -> io::Result<Bytes>;
53
54    /// Reads the bytes in the given range into the buffer.
55    ///
56    /// Handles the buffer based on its capacity:
57    /// - If the buffer is insufficient to hold the bytes, it will either:
58    ///   - Allocate additional space (e.g., for `Vec<u8>`)
59    ///   - Panic (e.g., for `&mut [u8]`)
60    async fn read_into(&self, range: Range<u64>, buf: &mut (impl BufMut + Send)) -> io::Result<()> {
61        let bytes = self.read(range).await?;
62        buf.put_slice(&bytes);
63        Ok(())
64    }
65
66    /// Reads the bytes in the given ranges.
67    async fn read_vec(&self, ranges: &[Range<u64>]) -> io::Result<Vec<Bytes>> {
68        let mut result = Vec::with_capacity(ranges.len());
69        for range in ranges {
70            result.push(self.read(range.clone()).await?);
71        }
72        Ok(result)
73    }
74}
75
76#[async_trait]
77impl<R: ?Sized + RangeReader> RangeReader for &R {
78    async fn metadata(&self) -> io::Result<Metadata> {
79        (*self).metadata().await
80    }
81
82    async fn read(&self, range: Range<u64>) -> io::Result<Bytes> {
83        (*self).read(range).await
84    }
85
86    async fn read_into(&self, range: Range<u64>, buf: &mut (impl BufMut + Send)) -> io::Result<()> {
87        (*self).read_into(range, buf).await
88    }
89
90    async fn read_vec(&self, ranges: &[Range<u64>]) -> io::Result<Vec<Bytes>> {
91        (*self).read_vec(ranges).await
92    }
93}
94
95/// `AsyncReadAdapter` adapts a `RangeReader` to an `AsyncRead`.
96#[pin_project]
97pub struct AsyncReadAdapter<R> {
98    /// The inner `RangeReader`.
99    /// Use `Mutex` to get rid of the borrow checker issue.
100    inner: Arc<Mutex<R>>,
101
102    /// The current position from the view of the reader.
103    position: u64,
104
105    /// The buffer for the read bytes.
106    buffer: Vec<u8>,
107
108    /// The length of the content.
109    content_length: u64,
110
111    /// The future for reading the next bytes.
112    #[pin]
113    read_fut: Option<Pin<Box<dyn Future<Output = io::Result<Bytes>> + Send>>>,
114}
115
116impl<R: RangeReader + 'static> AsyncReadAdapter<R> {
117    pub async fn new(inner: R) -> io::Result<Self> {
118        let inner = inner;
119        let metadata = inner.metadata().await?;
120        Ok(AsyncReadAdapter {
121            inner: Arc::new(Mutex::new(inner)),
122            position: 0,
123            buffer: Vec::new(),
124            content_length: metadata.content_length,
125            read_fut: None,
126        })
127    }
128}
129
130/// The maximum size per read for the inner reader in `AsyncReadAdapter`.
131const MAX_SIZE_PER_READ: usize = 8 * 1024 * 1024; // 8MB
132
133impl<R: RangeReader + 'static> AsyncRead for AsyncReadAdapter<R> {
134    fn poll_read(
135        mut self: Pin<&mut Self>,
136        cx: &mut Context<'_>,
137        buf: &mut [u8],
138    ) -> Poll<io::Result<usize>> {
139        let mut this = self.as_mut().project();
140
141        if *this.position >= *this.content_length {
142            return Poll::Ready(Ok(0));
143        }
144
145        if !this.buffer.is_empty() {
146            let to_read = this.buffer.len().min(buf.len());
147            buf[..to_read].copy_from_slice(&this.buffer[..to_read]);
148            this.buffer.drain(..to_read);
149            *this.position += to_read as u64;
150            return Poll::Ready(Ok(to_read));
151        }
152
153        if this.read_fut.is_none() {
154            let size = (*this.content_length - *this.position).min(MAX_SIZE_PER_READ as u64);
155            let range = *this.position..(*this.position + size);
156            let inner = this.inner.clone();
157            let fut = async move {
158                let inner = inner.lock().await;
159                inner.read(range).await
160            };
161
162            *this.read_fut = Some(Box::pin(fut));
163        }
164
165        match this
166            .read_fut
167            .as_mut()
168            .as_pin_mut()
169            .expect("checked above")
170            .poll(cx)
171        {
172            Poll::Pending => Poll::Pending,
173            Poll::Ready(Ok(bytes)) => {
174                *this.read_fut = None;
175
176                if !bytes.is_empty() {
177                    this.buffer.extend_from_slice(&bytes);
178                    self.poll_read(cx, buf)
179                } else {
180                    Poll::Ready(Ok(0))
181                }
182            }
183            Poll::Ready(Err(e)) => {
184                *this.read_fut = None;
185                Poll::Ready(Err(e))
186            }
187        }
188    }
189}
190
191#[async_trait]
192impl RangeReader for Vec<u8> {
193    async fn metadata(&self) -> io::Result<Metadata> {
194        Ok(Metadata {
195            content_length: self.len() as u64,
196        })
197    }
198
199    async fn read(&self, range: Range<u64>) -> io::Result<Bytes> {
200        let bytes = Bytes::copy_from_slice(&self[range.start as usize..range.end as usize]);
201        Ok(bytes)
202    }
203}
204
205// TODO(weny): considers replacing `tokio::fs::File` with opendal reader.
206/// `FileReader` is a `RangeReader` for reading a file.
207pub struct FileReader {
208    content_length: u64,
209    position: AtomicU64,
210    file: Mutex<tokio::fs::File>,
211}
212
213impl FileReader {
214    /// Creates a new `FileReader` for the file at the given path.
215    pub async fn new(path: impl AsRef<Path>) -> io::Result<Self> {
216        let file = tokio::fs::File::open(path).await?;
217        let metadata = file.metadata().await?;
218        Ok(FileReader {
219            content_length: metadata.len(),
220            position: AtomicU64::new(0),
221            file: Mutex::new(file),
222        })
223    }
224}
225
226impl SizeAwareRangeReader for FileReader {
227    fn with_file_size_hint(&mut self, _file_size_hint: u64) {
228        // do nothing
229    }
230}
231
232#[async_trait]
233impl RangeReader for FileReader {
234    async fn metadata(&self) -> io::Result<Metadata> {
235        Ok(Metadata {
236            content_length: self.content_length,
237        })
238    }
239
240    async fn read(&self, mut range: Range<u64>) -> io::Result<Bytes> {
241        let mut file = self.file.lock().await;
242
243        if range.start != self.position.load(Ordering::Relaxed) {
244            file.seek(io::SeekFrom::Start(range.start)).await?;
245            self.position.store(range.start, Ordering::Relaxed);
246        }
247
248        range.end = range.end.min(self.content_length);
249        if range.end <= self.position.load(Ordering::Relaxed) {
250            return Err(io::Error::new(
251                io::ErrorKind::UnexpectedEof,
252                "Start of range is out of bounds",
253            ));
254        }
255
256        let mut buf = vec![0; (range.end - range.start) as usize];
257
258        file.read_exact(&mut buf).await?;
259        self.position.store(range.end, Ordering::Relaxed);
260
261        Ok(Bytes::from(buf))
262    }
263}
264
#[cfg(test)]
mod tests {
    use common_test_util::temp_dir::create_named_temp_file;
    use futures::io::AsyncReadExt as _;

    use super::*;

    /// Wraps `source` in an `AsyncReadAdapter` and reads it to the end.
    async fn read_all_via_adapter(source: Vec<u8>) -> Vec<u8> {
        let mut adapter = AsyncReadAdapter::new(source).await.unwrap();
        let mut collected = Vec::new();
        adapter.read_to_end(&mut collected).await.unwrap();
        collected
    }

    #[tokio::test]
    async fn test_async_read_adapter() {
        let data = b"hello world";
        let collected = read_all_via_adapter(data.to_vec()).await;
        assert_eq!(collected, data);
    }

    #[tokio::test]
    async fn test_async_read_adapter_large() {
        // 20 MiB spans several `MAX_SIZE_PER_READ` chunks.
        let data: Vec<u8> = (0..20 * 1024 * 1024).map(|i| i as u8).collect();
        let collected = read_all_via_adapter(data.clone()).await;
        assert_eq!(collected, data);
    }

    #[tokio::test]
    async fn test_file_reader() {
        let data = b"hello world";
        let temp_file = create_named_temp_file();
        let path = temp_file.path();
        tokio::fs::write(path, data).await.unwrap();

        let reader = FileReader::new(path).await.unwrap();
        let content_length = reader.metadata().await.unwrap().content_length;
        assert_eq!(content_length, data.len() as u64);

        // Whole file first, then a prefix, which forces a seek back to 0.
        let whole = reader.read(0..content_length).await.unwrap();
        assert_eq!(&*whole, data);

        let prefix = reader.read(0..5).await.unwrap();
        assert_eq!(&*prefix, &data[..5]);
    }
}