use std::collections::HashMap;
use std::sync::Arc;
use common_wal::config::kafka::DatanodeKafkaConfig;
use rskafka::client::partition::{Compression, PartitionClient, UnknownTopicHandling};
use rskafka::client::ClientBuilder;
use rskafka::BackoffConfig;
use snafu::ResultExt;
use store_api::logstore::provider::KafkaProvider;
use tokio::sync::{Mutex, RwLock};
use crate::error::{
BuildClientSnafu, BuildPartitionClientSnafu, ResolveKafkaEndpointSnafu, Result, TlsConfigSnafu,
};
use crate::kafka::index::{GlobalIndexCollector, NoopCollector};
use crate::kafka::producer::{OrderedBatchProducer, OrderedBatchProducerRef};
/// Partition that every partition client in this module is created for;
/// `try_create_client` always targets partition 0 of a provider's topic.
pub const DEFAULT_PARTITION: i32 = 0;
/// Shared, reference-counted handle to a [`ClientManager`].
pub(crate) type ClientManagerRef = Arc<ClientManager>;
/// A partition client paired with the producer that writes to it.
///
/// Cloning is cheap: both fields are `Arc`-backed.
#[derive(Debug, Clone)]
pub(crate) struct Client {
    // rskafka client bound to one (topic, partition) pair.
    client: Arc<PartitionClient>,
    // Producer associated with this partition client (created alongside it
    // in `ClientManager::try_create_client`).
    producer: OrderedBatchProducerRef,
}
impl Client {
    /// Returns the producer that writes to this partition.
    pub(crate) fn producer(&self) -> &OrderedBatchProducerRef {
        &self.producer
    }

    /// Returns the underlying partition client.
    pub(crate) fn client(&self) -> &Arc<PartitionClient> {
        &self.client
    }
}
/// Creates and caches one [`Client`] per [`KafkaProvider`].
#[derive(Debug)]
pub(crate) struct ClientManager {
    // Top-level rskafka client used to spawn partition clients.
    client: rskafka::client::Client,
    // Serializes client creation so only one task builds a missing entry
    // (double-checked pattern together with `instances`).
    mutex: Mutex<()>,
    // Cache of created clients, keyed by provider.
    instances: RwLock<HashMap<Arc<KafkaProvider>, Client>>,
    // When present, supplies per-provider index collectors for new producers.
    global_index_collector: Option<GlobalIndexCollector>,
    // Byte budget handed to each producer (from `config.max_batch_bytes`).
    flush_batch_size: usize,
    // Compression applied by producers; fixed to LZ4 in `try_new`.
    compression: Compression,
}
impl ClientManager {
    /// Builds a manager connected to the brokers named in `config`.
    ///
    /// Resolves broker endpoints to IPv4, then constructs the underlying
    /// rskafka client with the configured backoff, SASL, and TLS settings.
    ///
    /// # Errors
    /// Fails if endpoint resolution, TLS configuration, or client
    /// construction fails.
    pub(crate) async fn try_new(
        config: &DatanodeKafkaConfig,
        global_index_collector: Option<GlobalIndexCollector>,
    ) -> Result<Self> {
        let broker_endpoints = common_wal::resolve_to_ipv4(&config.connection.broker_endpoints)
            .await
            .context(ResolveKafkaEndpointSnafu)?;

        let backoff_config = BackoffConfig {
            init_backoff: config.backoff.init,
            max_backoff: config.backoff.max,
            base: config.backoff.base as f64,
            deadline: config.backoff.deadline,
        };
        let mut builder = ClientBuilder::new(broker_endpoints).backoff_config(backoff_config);
        if let Some(sasl) = &config.connection.sasl {
            builder = builder.sasl_config(sasl.config.clone().into_sasl_config());
        }
        if let Some(tls) = &config.connection.tls {
            let tls_config = tls.to_tls_config().await.context(TlsConfigSnafu)?;
            builder = builder.tls_config(tls_config);
        }
        let client = builder.build().await.with_context(|_| BuildClientSnafu {
            broker_endpoints: config.connection.broker_endpoints.clone(),
        })?;

        Ok(Self {
            client,
            mutex: Mutex::new(()),
            instances: RwLock::new(HashMap::new()),
            flush_batch_size: config.max_batch_bytes.as_bytes() as usize,
            compression: Compression::Lz4,
            global_index_collector,
        })
    }

    /// Slow path of [`Self::get_or_insert`]: serializes creators behind
    /// `mutex`, re-checks the cache, and builds the client only if it is
    /// still absent.
    async fn try_insert(&self, provider: &Arc<KafkaProvider>) -> Result<Client> {
        let _guard = self.mutex.lock().await;

        // Re-check under the mutex: another task may have inserted the
        // client while we were waiting for the lock.
        if let Some(existing) = self.instances.read().await.get(provider).cloned() {
            return Ok(existing);
        }

        let created = self.try_create_client(provider).await?;
        self.instances
            .write()
            .await
            .insert(provider.clone(), created.clone());
        Ok(created)
    }

    /// Returns the cached [`Client`] for `provider`, creating it on first use.
    pub(crate) async fn get_or_insert(&self, provider: &Arc<KafkaProvider>) -> Result<Client> {
        // Fast path: only a shared read lock, no mutex.
        if let Some(existing) = self.instances.read().await.get(provider).cloned() {
            return Ok(existing);
        }
        self.try_insert(provider).await
    }

    /// Creates the partition client and producer pair for `provider`'s topic.
    async fn try_create_client(&self, provider: &Arc<KafkaProvider>) -> Result<Client> {
        let partition_client = self
            .client
            .partition_client(
                provider.topic.as_str(),
                DEFAULT_PARTITION,
                UnknownTopicHandling::Retry,
            )
            .await
            .context(BuildPartitionClientSnafu {
                topic: &provider.topic,
                partition: DEFAULT_PARTITION,
            })?;
        let partition_client = Arc::new(partition_client);

        let (tx, rx) = OrderedBatchProducer::channel();
        // With a global collector configured, the producer reports through a
        // provider-level collector; otherwise index collection is a no-op.
        let index_collector = match self.global_index_collector.as_ref() {
            Some(global) => {
                global
                    .provider_level_index_collector(provider.clone(), tx.clone())
                    .await
            }
            None => Box::new(NoopCollector),
        };
        let producer = Arc::new(OrderedBatchProducer::new(
            (tx, rx),
            provider.clone(),
            partition_client.clone(),
            self.compression,
            self.flush_batch_size,
            index_collector,
        ));
        Ok(Client {
            client: partition_client,
            producer,
        })
    }

    /// Returns the global index collector, if one was configured.
    pub(crate) fn global_index_collector(&self) -> Option<&GlobalIndexCollector> {
        self.global_index_collector.as_ref()
    }
}
#[cfg(test)]
mod tests {
    use common_wal::config::kafka::common::KafkaConnectionConfig;
    use common_wal::test_util::run_test_with_kafka_wal;
    use tokio::sync::Barrier;

    use super::*;

    /// Creates `num_topics` topics (one partition each, replication factor 1)
    /// on the given brokers, naming each topic via `decorator`.
    pub async fn create_topics<F>(
        num_topics: usize,
        decorator: F,
        broker_endpoints: &[String],
    ) -> Vec<String>
    where
        F: Fn(usize) -> String,
    {
        assert!(!broker_endpoints.is_empty());
        let client = ClientBuilder::new(broker_endpoints.to_vec())
            .build()
            .await
            .unwrap();
        let controller = client.controller_client().unwrap();

        let mut topics = Vec::with_capacity(num_topics);
        let mut tasks = Vec::with_capacity(num_topics);
        for i in 0..num_topics {
            let topic = decorator(i);
            tasks.push(controller.create_topic(topic.clone(), 1, 1, 500));
            topics.push(topic);
        }
        // Issue all creations concurrently and fail fast on any error.
        futures::future::try_join_all(tasks).await.unwrap();
        topics
    }

    /// Spins up `num_topics` uniquely named topics and a [`ClientManager`]
    /// pointed at the given brokers.
    async fn prepare(
        test_name: &str,
        num_topics: usize,
        broker_endpoints: Vec<String>,
    ) -> (ClientManager, Vec<String>) {
        // Unique suffix keeps reruns from colliding with leftover topics.
        let decorator = |i| format!("{test_name}_{}_{}", i, uuid::Uuid::new_v4());
        let topics = create_topics(num_topics, decorator, &broker_endpoints).await;

        let config = DatanodeKafkaConfig {
            connection: KafkaConnectionConfig {
                broker_endpoints,
                ..Default::default()
            },
            ..Default::default()
        };
        let manager = ClientManager::try_new(&config, None).await.unwrap();
        (manager, topics)
    }

    #[tokio::test]
    async fn test_sequential() {
        run_test_with_kafka_wal(|broker_endpoints| {
            Box::pin(async {
                let (manager, topics) = prepare("test_sequential", 128, broker_endpoints).await;
                // 512 regions spread over the 128 topics; duplicate topics
                // collapse in the map, exercising repeated lookups.
                let region_topic: HashMap<_, _> = (0..512)
                    .map(|region_id| (region_id, &topics[region_id % topics.len()]))
                    .collect();
                for topic in region_topic.into_values() {
                    let provider = Arc::new(KafkaProvider::new(topic.to_string()));
                    manager.get_or_insert(&provider).await.unwrap();
                }

                // Every topic must now have a cached client.
                let registry = manager.instances.read().await;
                assert!(topics.iter().all(|topic| {
                    let provider = Arc::new(KafkaProvider::new(topic.to_string()));
                    registry.contains_key(&provider)
                }));
            })
        })
        .await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_parallel() {
        run_test_with_kafka_wal(|broker_endpoints| {
            Box::pin(async {
                let (manager, topics) = prepare("test_parallel", 128, broker_endpoints).await;
                // Owned topic strings so each one can move into its task.
                let region_topic: HashMap<_, _> = (0..512)
                    .map(|region_id| (region_id, topics[region_id % topics.len()].clone()))
                    .collect();

                let manager = Arc::new(manager);
                let barrier = Arc::new(Barrier::new(region_topic.len()));
                let mut handles = Vec::with_capacity(region_topic.len());
                for topic in region_topic.into_values() {
                    let manager = manager.clone();
                    let barrier = barrier.clone();
                    handles.push(tokio::spawn(async move {
                        // Line up all tasks so lookups hit the manager at once.
                        barrier.wait().await;
                        let provider = Arc::new(KafkaProvider::new(topic));
                        assert!(manager.get_or_insert(&provider).await.is_ok());
                    }));
                }
                futures::future::try_join_all(handles).await.unwrap();

                // Every topic must now have a cached client.
                let registry = manager.instances.read().await;
                assert!(topics.iter().all(|topic| {
                    let provider = Arc::new(KafkaProvider::new(topic.to_string()));
                    registry.contains_key(&provider)
                }));
            })
        })
        .await;
    }
}