use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use common_time::Timestamp;
use datatypes::arrow::datatypes::SchemaRef;
use object_store::{FuturesAsyncWriter, ObjectStore};
use parquet::arrow::AsyncArrowWriter;
use parquet::basic::{Compression, Encoding, ZstdLevel};
use parquet::file::metadata::KeyValue;
use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder};
use parquet::schema::types::ColumnPath;
use smallvec::smallvec;
use snafu::ResultExt;
use store_api::metadata::RegionMetadataRef;
use store_api::storage::consts::SEQUENCE_COLUMN_NAME;
use store_api::storage::SequenceNumber;
use tokio::io::AsyncWrite;
use tokio_util::compat::{Compat, FuturesAsyncWriteCompatExt};
use crate::access_layer::{FilePathProvider, SstInfoArray};
use crate::error::{InvalidMetadataSnafu, OpenDalSnafu, Result, WriteParquetSnafu};
use crate::read::{Batch, Source};
use crate::sst::file::FileId;
use crate::sst::index::{Indexer, IndexerBuilder};
use crate::sst::parquet::format::WriteFormat;
use crate::sst::parquet::helper::parse_parquet_metadata;
use crate::sst::parquet::{SstInfo, WriteOptions, PARQUET_METADATA_KEY};
use crate::sst::{DEFAULT_WRITE_BUFFER_SIZE, DEFAULT_WRITE_CONCURRENCY};
/// Writer that streams batches into a parquet SST file and drives the
/// accompanying index build for that file.
pub struct ParquetWriter<F: WriterFactory, I: IndexerBuilder, P: FilePathProvider> {
    /// Resolves file ids to concrete SST file paths.
    path_provider: P,
    /// Active arrow writer; `None` until the first batch is written.
    writer: Option<AsyncArrowWriter<SizeAwareWriter<F::Writer>>>,
    /// Id of the SST file currently being written.
    current_file: FileId,
    /// Creates the underlying async writer when a new file is opened.
    writer_factory: F,
    /// Region metadata, embedded into the parquet footer on init.
    metadata: RegionMetadataRef,
    /// Builds an indexer for each new SST file.
    indexer_builder: I,
    /// Indexer for the file currently being written.
    current_indexer: Option<Indexer>,
    /// Total bytes written, shared with the size-tracking writer wrapper.
    bytes_written: Arc<AtomicUsize>,
}
/// Abstraction over how the destination writer for an SST file is created.
pub trait WriterFactory {
    type Writer: AsyncWrite + Send + Unpin;
    /// Opens an async writer for the given file path.
    fn create(&mut self, file_path: &str) -> impl Future<Output = Result<Self::Writer>>;
}
/// [WriterFactory] backed by an [ObjectStore].
pub struct ObjectStoreWriterFactory {
    object_store: ObjectStore,
}
impl WriterFactory for ObjectStoreWriterFactory {
    type Writer = Compat<FuturesAsyncWriter>;

    /// Opens a buffered, concurrent object-store writer for `file_path` and
    /// adapts it to the tokio [AsyncWrite] interface.
    async fn create(&mut self, file_path: &str) -> Result<Self::Writer> {
        let writer = self
            .object_store
            .writer_with(file_path)
            .chunk(DEFAULT_WRITE_BUFFER_SIZE.as_bytes() as usize)
            .concurrent(DEFAULT_WRITE_CONCURRENCY)
            .await
            .context(OpenDalSnafu)?;
        Ok(writer.into_futures_async_write().compat_write())
    }
}
impl<I, P> ParquetWriter<ObjectStoreWriterFactory, I, P>
where
    P: FilePathProvider,
    I: IndexerBuilder,
{
    /// Creates a parquet SST writer that persists data via `object_store`.
    pub async fn new_with_object_store(
        object_store: ObjectStore,
        metadata: RegionMetadataRef,
        indexer_builder: I,
        path_provider: P,
    ) -> ParquetWriter<ObjectStoreWriterFactory, I, P> {
        let factory = ObjectStoreWriterFactory { object_store };
        ParquetWriter::new(factory, metadata, indexer_builder, path_provider).await
    }
}
impl<F, I, P> ParquetWriter<F, I, P>
where
    F: WriterFactory,
    I: IndexerBuilder,
    P: FilePathProvider,
{
    /// Creates a writer. A fresh file id and its indexer are prepared eagerly;
    /// the parquet writer itself is created lazily on the first batch.
    pub async fn new(
        factory: F,
        metadata: RegionMetadataRef,
        indexer_builder: I,
        path_provider: P,
    ) -> ParquetWriter<F, I, P> {
        let init_file = FileId::random();
        let indexer = indexer_builder.build(init_file).await;
        ParquetWriter {
            path_provider,
            writer: None,
            current_file: init_file,
            writer_factory: factory,
            metadata,
            indexer_builder,
            current_indexer: Some(indexer),
            bytes_written: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns the indexer for the current file, allocating a new file id and
    /// indexer if the previous one was consumed.
    async fn get_or_create_indexer(&mut self) -> &mut Indexer {
        match self.current_indexer {
            None => {
                self.current_file = FileId::random();
                let indexer = self.indexer_builder.build(self.current_file).await;
                self.current_indexer = Some(indexer);
                // Safety: assigned `Some` on the line above.
                self.current_indexer.as_mut().unwrap()
            }
            Some(ref mut indexer) => indexer,
        }
    }

    /// Drains `source` into the SST file, feeding each batch to the indexer.
    ///
    /// Returns info for the written file, or an empty array if the source
    /// yielded no rows. When `override_sequence` is set it replaces the
    /// sequence of every written row. On write error the index build is
    /// aborted before the error is returned.
    pub async fn write_all(
        &mut self,
        mut source: Source,
        override_sequence: Option<SequenceNumber>,
        opts: &WriteOptions,
    ) -> Result<SstInfoArray> {
        let write_format =
            WriteFormat::new(self.metadata.clone()).with_override_sequence(override_sequence);
        let mut stats = SourceStats::default();

        while let Some(res) = self
            .write_next_batch(&mut source, &write_format, opts)
            .await
            .transpose()
        {
            match res {
                Ok(mut batch) => {
                    stats.update(&batch);
                    self.get_or_create_indexer().await.update(&mut batch).await;
                }
                Err(e) => {
                    // Discard partially built index data for the failed file.
                    self.get_or_create_indexer().await.abort().await;
                    return Err(e);
                }
            }
        }

        let index_output = self.get_or_create_indexer().await.finish().await;

        if stats.num_rows == 0 {
            return Ok(smallvec![]);
        }

        let Some(mut arrow_writer) = self.writer.take() else {
            // No batch reached the parquet writer, so there is no file to report.
            return Ok(smallvec![]);
        };

        arrow_writer.flush().await.context(WriteParquetSnafu)?;

        let file_meta = arrow_writer.close().await.context(WriteParquetSnafu)?;
        let file_size = self.bytes_written.load(Ordering::Relaxed) as u64;

        // Invariant: num_rows > 0 means at least one non-empty batch updated
        // the stats, which always sets the time range.
        let time_range = stats
            .time_range
            .expect("time range must be set when rows were written");

        let parquet_metadata = parse_parquet_metadata(file_meta)?;
        let file_id = self.current_file;

        Ok(smallvec![SstInfo {
            file_id,
            time_range,
            file_size,
            num_rows: stats.num_rows,
            num_row_groups: parquet_metadata.num_row_groups() as u64,
            file_metadata: Some(Arc::new(parquet_metadata)),
            index_metadata: index_output,
        }])
    }

    /// Applies per-column tuning: delta-binary-packed encoding with
    /// dictionaries disabled for the time index and sequence columns.
    fn customize_column_config(
        builder: WriterPropertiesBuilder,
        region_metadata: &RegionMetadataRef,
    ) -> WriterPropertiesBuilder {
        let ts_col = ColumnPath::new(vec![region_metadata
            .time_index_column()
            .column_schema
            .name
            .clone()]);
        let seq_col = ColumnPath::new(vec![SEQUENCE_COLUMN_NAME.to_string()]);
        builder
            .set_column_encoding(seq_col.clone(), Encoding::DELTA_BINARY_PACKED)
            .set_column_dictionary_enabled(seq_col, false)
            .set_column_encoding(ts_col.clone(), Encoding::DELTA_BINARY_PACKED)
            .set_column_dictionary_enabled(ts_col, false)
    }

    /// Fetches one batch from `source`, converts it to arrow format and writes
    /// it out. Returns the original batch so the caller can index it, or
    /// `None` when the source is exhausted.
    async fn write_next_batch(
        &mut self,
        source: &mut Source,
        write_format: &WriteFormat,
        opts: &WriteOptions,
    ) -> Result<Option<Batch>> {
        let Some(batch) = source.next_batch().await? else {
            return Ok(None);
        };
        let arrow_batch = write_format.convert_batch(&batch)?;
        self.maybe_init_writer(write_format.arrow_schema(), opts)
            .await?
            .write(&arrow_batch)
            .await
            .context(WriteParquetSnafu)?;
        Ok(Some(batch))
    }

    /// Returns the active arrow writer, creating it (and opening the target
    /// file) on first use.
    async fn maybe_init_writer(
        &mut self,
        schema: &SchemaRef,
        opts: &WriteOptions,
    ) -> Result<&mut AsyncArrowWriter<SizeAwareWriter<F::Writer>>> {
        if let Some(ref mut w) = self.writer {
            Ok(w)
        } else {
            // Embed the region metadata into the parquet footer so readers can
            // recover it from the file alone.
            let json = self.metadata.to_json().context(InvalidMetadataSnafu)?;
            let key_value_meta = KeyValue::new(PARQUET_METADATA_KEY.to_string(), json);

            let props_builder = WriterProperties::builder()
                .set_key_value_metadata(Some(vec![key_value_meta]))
                .set_compression(Compression::ZSTD(ZstdLevel::default()))
                .set_encoding(Encoding::PLAIN)
                .set_max_row_group_size(opts.row_group_size);
            let props_builder = Self::customize_column_config(props_builder, &self.metadata);
            let writer_props = props_builder.build();

            let sst_file_path = self.path_provider.build_sst_file_path(self.current_file);
            let writer = SizeAwareWriter::new(
                self.writer_factory.create(&sst_file_path).await?,
                self.bytes_written.clone(),
            );
            let arrow_writer =
                AsyncArrowWriter::try_new(writer, schema.clone(), Some(writer_props))
                    .context(WriteParquetSnafu)?;
            self.writer = Some(arrow_writer);
            // Safety: assigned `Some` on the line above.
            Ok(self.writer.as_mut().unwrap())
        }
    }
}
/// Statistics accumulated over all batches written to one SST file.
#[derive(Default)]
struct SourceStats {
    /// Total number of rows seen so far.
    num_rows: usize,
    /// Inclusive (min, max) timestamp range; `None` until a non-empty batch arrives.
    time_range: Option<(Timestamp, Timestamp)>,
}
impl SourceStats {
    /// Folds one batch into the statistics; empty batches are ignored.
    fn update(&mut self, batch: &Batch) {
        if batch.is_empty() {
            return;
        }
        self.num_rows += batch.num_rows();
        // Safety: the batch is non-empty, so both timestamps exist.
        let batch_min = batch.first_timestamp().unwrap();
        let batch_max = batch.last_timestamp().unwrap();
        self.time_range = Some(match self.time_range.take() {
            Some((min, max)) => (min.min(batch_min), max.max(batch_max)),
            None => (batch_min, batch_max),
        });
    }
}
/// [AsyncWrite] wrapper that counts the bytes successfully written through it
/// into a shared atomic counter.
struct SizeAwareWriter<W> {
    /// Underlying writer all operations are delegated to.
    inner: W,
    /// Shared running total of bytes written.
    size: Arc<AtomicUsize>,
}
impl<W> SizeAwareWriter<W> {
    /// Wraps `inner`, accumulating successfully written byte counts into `size`.
    fn new(inner: W, size: Arc<AtomicUsize>) -> Self {
        // `size` is already owned here; cloning the Arc again was redundant.
        Self { inner, size }
    }
}
impl<W> AsyncWrite for SizeAwareWriter<W>
where
    W: AsyncWrite + Unpin,
{
    /// Delegates to the inner writer and, on a successful write, adds the
    /// number of accepted bytes to the shared counter.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        let this = self.get_mut();
        let poll = Pin::new(&mut this.inner).poll_write(cx, buf);
        if let Poll::Ready(Ok(bytes_written)) = &poll {
            this.size.fetch_add(*bytes_written, Ordering::Relaxed);
        }
        poll
    }

    /// Flushes the inner writer.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    /// Shuts down the inner writer.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}