// common_base/range_read.rs
use std::future::Future;
use std::io;
use std::ops::Range;
use std::path::Path;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use async_trait::async_trait;
use bytes::{BufMut, Bytes};
use futures::AsyncRead;
use pin_project::pin_project;
use tokio::io::{AsyncReadExt as _, AsyncSeekExt as _};
use tokio::sync::Mutex;
/// Metadata describing the source behind a `RangeReader`.
///
/// The type is a plain value: it can be copied, compared and debug-printed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Metadata {
    /// Total length of the source, in bytes.
    pub content_length: u64,
}
/// A [`RangeReader`] that can additionally be given a hint about the size of
/// the file it reads.
pub trait SizeAwareRangeReader: RangeReader {
/// Supplies a hint for the size of the underlying file.
///
/// NOTE(review): the unit is presumably bytes, and what an implementation does
/// with the hint is up to the implementor — confirm against the implementations.
fn with_file_size_hint(&mut self, file_size_hint: u64);
}
#[async_trait]
/// Random-access reader over a source that is addressed by `u64` ranges.
///
/// Implementors only have to provide [`metadata`](RangeReader::metadata) and
/// [`read`](RangeReader::read); the remaining methods have default
/// implementations built on top of `read`.
pub trait RangeReader: Sync + Send + Unpin {
    /// Returns the metadata of the source.
    async fn metadata(&self) -> io::Result<Metadata>;

    /// Reads the bytes covered by `range`.
    async fn read(&self, range: Range<u64>) -> io::Result<Bytes>;

    /// Reads the bytes covered by `range` and appends them to `buf`.
    ///
    /// The default implementation performs a [`read`](RangeReader::read) and
    /// copies the result into `buf`.
    async fn read_into(&self, range: Range<u64>, buf: &mut (impl BufMut + Send)) -> io::Result<()> {
        self.read(range)
            .await
            .map(|chunk| buf.put_slice(&chunk))
    }

    /// Reads every range in `ranges`, returning the results in the same order.
    ///
    /// The default implementation issues the reads one after another and stops
    /// at the first error.
    async fn read_vec(&self, ranges: &[Range<u64>]) -> io::Result<Vec<Bytes>> {
        let mut chunks = Vec::with_capacity(ranges.len());
        for wanted in ranges.iter().cloned() {
            let chunk = self.read(wanted).await?;
            chunks.push(chunk);
        }
        Ok(chunks)
    }
}
#[async_trait]
/// A shared reference to a reader is itself a reader: every method forwards to
/// the referenced `R`.
impl<R: ?Sized + RangeReader> RangeReader for &R {
    /// Forwards to `R::metadata`.
    async fn metadata(&self) -> io::Result<Metadata> {
        R::metadata(*self).await
    }

    /// Forwards to `R::read`.
    async fn read(&self, range: Range<u64>) -> io::Result<Bytes> {
        R::read(*self, range).await
    }

    /// Forwards to `R::read_into`, so an override on `R` is honoured.
    async fn read_into(&self, range: Range<u64>, buf: &mut (impl BufMut + Send)) -> io::Result<()> {
        R::read_into(*self, range, buf).await
    }

    /// Forwards to `R::read_vec`, so an override on `R` is honoured.
    async fn read_vec(&self, ranges: &[Range<u64>]) -> io::Result<Vec<Bytes>> {
        R::read_vec(*self, ranges).await
    }
}
#[pin_project]
pub struct AsyncReadAdapter<R> {
inner: Arc<Mutex<R>>,
position: u64,
buffer: Vec<u8>,
content_length: u64,
#[pin]
read_fut: Option<Pin<Box<dyn Future<Output = io::Result<Bytes>> + Send>>>,
}
impl<R: RangeReader + 'static> AsyncReadAdapter<R> {
pub async fn new(inner: R) -> io::Result<Self> {
let inner = inner;
let metadata = inner.metadata().await?;
Ok(AsyncReadAdapter {
inner: Arc::new(Mutex::new(inner)),
position: 0,
buffer: Vec::new(),
content_length: metadata.content_length,
read_fut: None,
})
}
}
const MAX_SIZE_PER_READ: usize = 8 * 1024 * 1024; impl<R: RangeReader + 'static> AsyncRead for AsyncReadAdapter<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut this = self.as_mut().project();
if *this.position >= *this.content_length {
return Poll::Ready(Ok(0));
}
if !this.buffer.is_empty() {
let to_read = this.buffer.len().min(buf.len());
buf[..to_read].copy_from_slice(&this.buffer[..to_read]);
this.buffer.drain(..to_read);
*this.position += to_read as u64;
return Poll::Ready(Ok(to_read));
}
if this.read_fut.is_none() {
let size = (*this.content_length - *this.position).min(MAX_SIZE_PER_READ as u64);
let range = *this.position..(*this.position + size);
let inner = this.inner.clone();
let fut = async move {
let inner = inner.lock().await;
inner.read(range).await
};
*this.read_fut = Some(Box::pin(fut));
}
match this
.read_fut
.as_mut()
.as_pin_mut()
.expect("checked above")
.poll(cx)
{
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(bytes)) => {
*this.read_fut = None;
if !bytes.is_empty() {
this.buffer.extend_from_slice(&bytes);
self.poll_read(cx, buf)
} else {
Poll::Ready(Ok(0))
}
}
Poll::Ready(Err(e)) => {
*this.read_fut = None;
Poll::Ready(Err(e))
}
}
}
}
#[async_trait]
/// In-memory reader: the vector's contents are the source.
impl RangeReader for Vec<u8> {
    /// Reports the vector's length as the content length.
    async fn metadata(&self) -> io::Result<Metadata> {
        Ok(Metadata {
            content_length: self.len() as u64,
        })
    }

    /// Returns a copy of the bytes in `range`.
    ///
    /// # Errors
    ///
    /// Instead of panicking on a bad range, returns
    /// * `InvalidInput` if `range.start > range.end`, and
    /// * `UnexpectedEof` if `range.end` lies beyond the end of the vector.
    async fn read(&self, range: Range<u64>) -> io::Result<Bytes> {
        let len = self.len() as u64;

        if range.start > range.end {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "invalid range {}..{}: start is greater than end",
                    range.start, range.end
                ),
            ));
        }
        if range.end > len {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                format!(
                    "range {}..{} is out of bounds for a buffer of {} bytes",
                    range.start, range.end, len
                ),
            ));
        }

        // Both bounds are <= self.len() here, so the casts cannot truncate.
        let (start, end) = (range.start as usize, range.end as usize);
        Ok(Bytes::copy_from_slice(&self[start..end]))
    }
}
/// A [`RangeReader`] backed by a file opened through `tokio::fs`.
pub struct FileReader {
/// Length of the file, taken from its metadata when the reader was created.
content_length: u64,
/// Offset the reader believes the file cursor is at; `read` compares the
/// requested start against it to decide whether a seek is needed.
position: AtomicU64,
/// The open file handle; `read` locks it for the duration of each call.
file: Mutex<tokio::fs::File>,
}
impl FileReader {
    /// Opens the file at `path` and records its current length.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be opened or its metadata cannot be queried.
    pub async fn new(path: impl AsRef<Path>) -> io::Result<Self> {
        let handle = tokio::fs::File::open(path).await?;
        let content_length = handle.metadata().await?.len();

        Ok(Self {
            content_length,
            position: AtomicU64::new(0),
            file: Mutex::new(handle),
        })
    }
}
// Only compiled for tests or when the `testing` feature is enabled.
#[cfg(any(test, feature = "testing"))]
impl SizeAwareRangeReader for FileReader {
// The hint is ignored (empty body); `FileReader::new` already records the
// length from the file's metadata.
fn with_file_size_hint(&mut self, _file_size_hint: u64) {
}
}
#[async_trait]
impl RangeReader for FileReader {
    /// Returns the file length captured when the reader was created.
    async fn metadata(&self) -> io::Result<Metadata> {
        Ok(Metadata {
            content_length: self.content_length,
        })
    }

    /// Reads `range` from the file, clamping `range.end` to the file length.
    ///
    /// A seek is issued only when the requested start differs from the cursor
    /// position left behind by the previous successful read.
    ///
    /// # Errors
    ///
    /// * `UnexpectedEof` if the clamped range is empty, i.e. `range.start` is at
    ///   or past the clamped end.
    /// * Any error reported by the underlying seek or read.
    async fn read(&self, range: Range<u64>) -> io::Result<Bytes> {
        // Marks the cursor as "not known". No accepted request can start here:
        // a start of `u64::MAX` is always rejected by the bounds check below.
        const POSITION_UNKNOWN: u64 = u64::MAX;

        let start = range.start;
        let end = range.end.min(self.content_length);
        if end <= start {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "Start of range is out of bounds",
            ));
        }

        let mut file = self.file.lock().await;

        // Invalidate the cached position before any I/O that moves the cursor.
        // If the seek or read below fails, or this future is dropped half-way,
        // the next call sees `POSITION_UNKNOWN` and re-seeks instead of trusting
        // a stale offset and returning bytes from the wrong place.
        let cached = self.position.swap(POSITION_UNKNOWN, Ordering::Relaxed);
        if cached != start {
            file.seek(io::SeekFrom::Start(start)).await?;
        }

        let mut buf = vec![0; (end - start) as usize];
        file.read_exact(&mut buf).await?;

        // The read completed, so the cursor is known to sit at `end`.
        self.position.store(end, Ordering::Relaxed);
        Ok(Bytes::from(buf))
    }
}
#[cfg(test)]
mod tests {
use common_test_util::temp_dir::create_named_temp_file;
use futures::io::AsyncReadExt as _;
use super::*;
// Reading a small in-memory source through the adapter yields the original bytes.
#[tokio::test]
async fn test_async_read_adapter() {
let data = b"hello world";
let reader = Vec::from(data);
let mut adapter = AsyncReadAdapter::new(reader).await.unwrap();
let mut buf = Vec::new();
adapter.read_to_end(&mut buf).await.unwrap();
assert_eq!(buf, data);
}
// A 20 MiB source is larger than MAX_SIZE_PER_READ (8 MiB), so the adapter has
// to issue several range reads to deliver all of it.
#[tokio::test]
async fn test_async_read_adapter_large() {
let data = (0..20 * 1024 * 1024).map(|i| i as u8).collect::<Vec<u8>>();
let mut adapter = AsyncReadAdapter::new(data.clone()).await.unwrap();
let mut buf = Vec::new();
adapter.read_to_end(&mut buf).await.unwrap();
assert_eq!(buf, data);
}
// FileReader reports the file length and serves both a full-range read and a
// later read that starts back at offset 0 (which requires seeking backwards).
#[tokio::test]
async fn test_file_reader() {
let file = create_named_temp_file();
let path = file.path();
let data = b"hello world";
tokio::fs::write(path, data).await.unwrap();
let reader = FileReader::new(path).await.unwrap();
let metadata = reader.metadata().await.unwrap();
assert_eq!(metadata.content_length, data.len() as u64);
// Full-range read; afterwards the tracked position is at the end of the file.
let bytes = reader.read(0..metadata.content_length).await.unwrap();
assert_eq!(&*bytes, data);
// The start (0) differs from the tracked position, so this read must seek.
let bytes = reader.read(0..5).await.unwrap();
assert_eq!(&*bytes, &data[..5]);
}
}