use std::io;
use std::ops::Range;
use std::pin::Pin;
use std::task::{Context, Poll};
use async_trait::async_trait;
use bytes::{BufMut, Bytes};
use common_base::range_read::{Metadata, RangeReader, SizeAwareRangeReader};
use futures::{AsyncRead, AsyncSeek, AsyncWrite};
use object_store::ObjectStore;
use pin_project::pin_project;
use prometheus::IntCounter;
use snafu::ResultExt;
use crate::error::{OpenDalSnafu, Result};
/// Thin wrapper around an [`ObjectStore`] whose readers and writers report
/// their I/O activity to caller-supplied [`IntCounter`]s.
#[derive(Clone)]
pub(crate) struct InstrumentedStore {
    /// Underlying store that every operation is delegated to.
    object_store: ObjectStore,
    /// Chunk size passed to `writer_with(..).chunk(..)` when opening writers;
    /// `None` selects the plain `writer`. `with_write_buffer_size` never
    /// stores `Some(0)`.
    write_buffer_size: Option<usize>,
}
impl InstrumentedStore {
    /// Wraps `object_store`. Writers are opened with the store's plain
    /// `writer` until [`Self::with_write_buffer_size`] configures a chunk size.
    pub fn new(object_store: ObjectStore) -> Self {
        let write_buffer_size = None;
        Self {
            object_store,
            write_buffer_size,
        }
    }

    /// Sets the chunk size used by [`Self::writer`].
    ///
    /// Both `None` and `Some(0)` mean "no chunking".
    pub fn with_write_buffer_size(mut self, write_buffer_size: Option<usize>) -> Self {
        // A zero-sized chunk is normalised to `None` so `writer` never asks for it.
        self.write_buffer_size = write_buffer_size.filter(|size| *size != 0);
        self
    }

    /// Builds a range reader for `path` that records bytes read and read calls.
    ///
    /// No I/O is performed here; the returned reader opens the object on
    /// every request.
    pub async fn range_reader<'a>(
        &self,
        path: &str,
        read_byte_count: &'a IntCounter,
        read_count: &'a IntCounter,
    ) -> Result<InstrumentedRangeReader<'a>> {
        let range_reader = InstrumentedRangeReader {
            store: self.object_store.clone(),
            path: path.to_owned(),
            read_byte_count,
            read_count,
            file_size_hint: None,
        };
        Ok(range_reader)
    }

    /// Opens `path` as an instrumented async reader spanning the whole object.
    ///
    /// # Errors
    /// Returns an `OpenDal` error if the object cannot be stat-ed or opened.
    pub async fn reader<'a>(
        &self,
        path: &str,
        read_byte_count: &'a IntCounter,
        read_count: &'a IntCounter,
        seek_count: &'a IntCounter,
    ) -> Result<InstrumentedAsyncRead<'a, object_store::FuturesAsyncReader>> {
        // The object length bounds the range handed to the futures adapter.
        let content_length = self
            .object_store
            .stat(path)
            .await
            .context(OpenDalSnafu)?
            .content_length();
        let opened = self.object_store.reader(path).await.context(OpenDalSnafu)?;
        let inner = opened
            .into_futures_async_read(0..content_length)
            .await
            .context(OpenDalSnafu)?;
        Ok(InstrumentedAsyncRead::new(
            inner,
            read_byte_count,
            read_count,
            seek_count,
        ))
    }

    /// Opens `path` for writing and wraps it with write/flush instrumentation.
    ///
    /// Uses a chunked writer when a write buffer size is configured.
    ///
    /// # Errors
    /// Returns an `OpenDal` error if the writer cannot be created.
    pub async fn writer<'a>(
        &self,
        path: &str,
        write_byte_count: &'a IntCounter,
        write_count: &'a IntCounter,
        flush_count: &'a IntCounter,
    ) -> Result<InstrumentedAsyncWrite<'a, object_store::FuturesAsyncWriter>> {
        let opened = match self.write_buffer_size {
            Some(chunk_size) => {
                self.object_store
                    .writer_with(path)
                    .chunk(chunk_size)
                    .await
            }
            None => self.object_store.writer(path).await,
        };
        let inner = opened.context(OpenDalSnafu)?.into_futures_async_write();
        Ok(InstrumentedAsyncWrite::new(
            inner,
            write_byte_count,
            write_count,
            flush_count,
        ))
    }

    /// Lists the entries under `path`.
    pub async fn list(&self, path: &str) -> Result<Vec<object_store::Entry>> {
        self.object_store.list(path).await.context(OpenDalSnafu)
    }

    /// Recursively removes everything under `path`.
    pub async fn remove_all(&self, path: &str) -> Result<()> {
        self.object_store
            .remove_all(path)
            .await
            .context(OpenDalSnafu)?;
        Ok(())
    }
}
/// Async reader wrapper that tracks bytes read, completed read polls and
/// completed seeks.
///
/// Counts are accumulated in [`CounterGuard`]s and only added to the
/// underlying counters when this wrapper is dropped.
#[pin_project]
pub(crate) struct InstrumentedAsyncRead<'a, R> {
    /// Wrapped reader (structurally pinned).
    #[pin]
    inner: R,
    /// Sum of byte counts returned by `poll_read` calls that completed with `Ok`.
    read_byte_count: CounterGuard<'a>,
    /// Number of `poll_read` calls that completed with `Ok`.
    read_count: CounterGuard<'a>,
    /// Number of `poll_seek` calls that completed with `Ok`.
    seek_count: CounterGuard<'a>,
}
impl<'a, R> InstrumentedAsyncRead<'a, R> {
    /// Wraps `inner`, attaching a fresh buffering guard to each counter.
    fn new(
        inner: R,
        read_byte_count: &'a IntCounter,
        read_count: &'a IntCounter,
        seek_count: &'a IntCounter,
    ) -> Self {
        let read_byte_count = CounterGuard::new(read_byte_count);
        let read_count = CounterGuard::new(read_count);
        let seek_count = CounterGuard::new(seek_count);
        Self {
            inner,
            read_byte_count,
            read_count,
            seek_count,
        }
    }
}
impl<R: AsyncRead + Unpin + Send> AsyncRead for InstrumentedAsyncRead<'_, R> {
    /// Delegates to the inner reader. Every `Ready(Ok(n))` outcome, including
    /// `n == 0`, counts as one read call and adds `n` to the byte total.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        let outcome = this.inner.poll_read(cx, buf);
        if let Poll::Ready(Ok(bytes_read)) = &outcome {
            this.read_count.inc_by(1);
            this.read_byte_count.inc_by(*bytes_read);
        }
        outcome
    }
}
impl<R: AsyncSeek + Unpin + Send> AsyncSeek for InstrumentedAsyncRead<'_, R> {
    /// Delegates to the inner reader and counts the seek once it succeeds.
    fn poll_seek(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        pos: io::SeekFrom,
    ) -> Poll<io::Result<u64>> {
        let this = self.project();
        let outcome = this.inner.poll_seek(cx, pos);
        if matches!(outcome, Poll::Ready(Ok(_))) {
            this.seek_count.inc_by(1);
        }
        outcome
    }
}
/// Async writer wrapper that tracks bytes written, completed write polls and
/// completed flushes.
///
/// Counts are accumulated in [`CounterGuard`]s and only added to the
/// underlying counters when this wrapper is dropped.
#[pin_project]
pub(crate) struct InstrumentedAsyncWrite<'a, W> {
    /// Wrapped writer (structurally pinned).
    #[pin]
    inner: W,
    /// Sum of byte counts accepted by `poll_write` calls that completed with `Ok`.
    write_byte_count: CounterGuard<'a>,
    /// Number of `poll_write` calls that completed with `Ok`.
    write_count: CounterGuard<'a>,
    /// Number of `poll_flush` calls that completed with `Ok`.
    flush_count: CounterGuard<'a>,
}
impl<'a, W> InstrumentedAsyncWrite<'a, W> {
    /// Wraps `inner`, attaching a fresh buffering guard to each counter.
    fn new(
        inner: W,
        write_byte_count: &'a IntCounter,
        write_count: &'a IntCounter,
        flush_count: &'a IntCounter,
    ) -> Self {
        let write_byte_count = CounterGuard::new(write_byte_count);
        let write_count = CounterGuard::new(write_count);
        let flush_count = CounterGuard::new(flush_count);
        Self {
            inner,
            write_byte_count,
            write_count,
            flush_count,
        }
    }
}
impl<W: AsyncWrite + Unpin + Send> AsyncWrite for InstrumentedAsyncWrite<'_, W> {
    /// Delegates to the inner writer. A `Ready(Ok(n))` outcome counts as one
    /// write call and adds `n` to the byte total.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        let outcome = this.inner.poll_write(cx, buf);
        if let Poll::Ready(Ok(written)) = &outcome {
            this.write_count.inc_by(1);
            this.write_byte_count.inc_by(*written);
        }
        outcome
    }

    /// Delegates to the inner writer and counts the flush once it succeeds.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.project();
        let outcome = this.inner.poll_flush(cx);
        if matches!(outcome, Poll::Ready(Ok(()))) {
            this.flush_count.inc_by(1);
        }
        outcome
    }

    /// Delegates to the inner writer; closing is not counted.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.project();
        this.inner.poll_close(cx)
    }
}
/// [`RangeReader`] over a single object that increments its counters
/// directly after each successful read request.
pub(crate) struct InstrumentedRangeReader<'a> {
    /// Store used to open a fresh reader for every request.
    store: ObjectStore,
    /// Object path within `store`.
    path: String,
    /// Incremented by the number of bytes returned from each read request.
    read_byte_count: &'a IntCounter,
    /// Incremented once per successful read request (a `read_vec` call counts once).
    read_count: &'a IntCounter,
    /// Known object size; when set, `metadata` returns it without issuing a `stat`.
    file_size_hint: Option<u64>,
}
impl SizeAwareRangeReader for InstrumentedRangeReader<'_> {
    /// Remembers the object size so `metadata` can answer without a `stat` call.
    fn with_file_size_hint(&mut self, hint: u64) {
        self.file_size_hint = Some(hint);
    }
}
#[async_trait]
impl RangeReader for InstrumentedRangeReader<'_> {
    /// Reports the object length. The size hint is preferred over a `stat` round trip.
    async fn metadata(&self) -> io::Result<Metadata> {
        let content_length = if let Some(hint) = self.file_size_hint {
            hint
        } else {
            self.store.stat(&self.path).await?.content_length()
        };
        Ok(Metadata { content_length })
    }

    /// Reads `range` in a single request. Records the returned byte count and
    /// one read call.
    async fn read(&self, range: Range<u64>) -> io::Result<Bytes> {
        let reader = self.store.reader(&self.path).await?;
        let buffer = reader.read(range).await?;
        self.read_byte_count.inc_by(buffer.len() as u64);
        self.read_count.inc_by(1);
        Ok(buffer.to_bytes())
    }

    /// Reads `range` directly into `buf`. Records the copied byte count and
    /// one read call.
    async fn read_into(&self, range: Range<u64>, buf: &mut (impl BufMut + Send)) -> io::Result<()> {
        let reader = self.store.reader(&self.path).await?;
        let copied = reader.read_into(buf, range).await?;
        self.read_byte_count.inc_by(copied as u64);
        self.read_count.inc_by(1);
        Ok(())
    }

    /// Fetches all `ranges` in one request. Records their combined size and
    /// a single read call.
    async fn read_vec(&self, ranges: &[Range<u64>]) -> io::Result<Vec<Bytes>> {
        let reader = self.store.reader(&self.path).await?;
        let buffers = reader.fetch(ranges.to_vec()).await?;
        let total_size: usize = buffers.iter().map(|b| b.len()).sum();
        self.read_byte_count.inc_by(total_size as u64);
        self.read_count.inc_by(1);
        Ok(buffers.into_iter().map(|b| b.to_bytes()).collect())
    }
}
/// Locally accumulates increments for an [`IntCounter`] and publishes the
/// total with a single `inc_by` call when dropped.
///
/// The async read/write wrappers call [`CounterGuard::inc_by`] from their
/// poll functions, so the shared counter is only touched once per wrapper.
struct CounterGuard<'a> {
    /// Pending increments not yet published to `counter`.
    ///
    /// Stored as `u64` (the counter's own width) rather than `usize`. On
    /// 32-bit targets, a `usize` total wraps or panics after 4 GiB passes
    /// through a single wrapper.
    count: u64,
    /// Counter that receives the accumulated total on drop.
    counter: &'a IntCounter,
}

impl<'a> CounterGuard<'a> {
    /// Creates a guard with nothing pending.
    fn new(counter: &'a IntCounter) -> Self {
        Self { count: 0, counter }
    }

    /// Adds `n` to the pending total.
    ///
    /// Saturates instead of overflowing so a poll function can never panic
    /// because of metrics bookkeeping.
    fn inc_by(&mut self, n: usize) {
        // `usize` is at most 64 bits on supported targets, so widening is lossless.
        self.count = self.count.saturating_add(n as u64);
    }
}

impl Drop for CounterGuard<'_> {
    /// Publishes the pending total; skips the call when nothing was recorded.
    fn drop(&mut self) {
        if self.count > 0 {
            self.counter.inc_by(self.count);
        }
    }
}
#[cfg(test)]
mod tests {
    use futures::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
    use object_store::services::Memory;

    use super::*;

    /// Round-trips a small file through the instrumented writer and reader.
    /// Checks every counter after the wrappers have gone out of scope.
    #[tokio::test]
    async fn test_instrumented_store_read_write() {
        let store =
            InstrumentedStore::new(ObjectStore::new(Memory::default()).unwrap().finish());

        // Each counter uses its name as its help text.
        let new_counter = |name: &str| IntCounter::new(name, name).unwrap();
        let read_byte_count = new_counter("read_byte_count");
        let read_count = new_counter("read_count");
        let seek_count = new_counter("seek_count");
        let write_byte_count = new_counter("write_byte_count");
        let write_count = new_counter("write_count");
        let flush_count = new_counter("flush_count");

        // The wrappers buffer their counts and publish them on drop, so each
        // one is confined to its own scope.
        {
            let mut writer = store
                .writer("my_file", &write_byte_count, &write_count, &flush_count)
                .await
                .unwrap();
            writer.write_all(b"hello").await.unwrap();
            writer.flush().await.unwrap();
            writer.close().await.unwrap();
        }

        {
            let mut reader = store
                .reader("my_file", &read_byte_count, &read_count, &seek_count)
                .await
                .unwrap();
            let mut buf = vec![0; 5];
            reader.read_exact(&mut buf).await.unwrap();
            reader.seek(io::SeekFrom::Start(0)).await.unwrap();
            reader.read_exact(&mut buf).await.unwrap();
        }

        assert_eq!(read_byte_count.get(), 10);
        assert_eq!(read_count.get(), 2);
        assert_eq!(seek_count.get(), 1);
        assert_eq!(write_byte_count.get(), 5);
        assert_eq!(write_count.get(), 1);
        assert_eq!(flush_count.get(), 1);
    }
}