1use std::sync::Arc;
16
17use api::v1::meta::heartbeat_client::HeartbeatClient;
18use api::v1::meta::{HeartbeatRequest, HeartbeatResponse, RequestHeader, Role};
19use common_grpc::channel_manager::ChannelManager;
20use common_meta::util;
21use common_telemetry::info;
22use common_telemetry::tracing_context::TracingContext;
23use snafu::{ensure, OptionExt, ResultExt};
24use tokio::sync::{mpsc, RwLock};
25use tokio_stream::wrappers::ReceiverStream;
26use tonic::codec::CompressionEncoding;
27use tonic::transport::Channel;
28use tonic::Streaming;
29
30use crate::client::ask_leader::AskLeader;
31use crate::client::{Id, LeaderProviderRef};
32use crate::error;
33use crate::error::{InvalidResponseHeaderSnafu, Result};
34
35pub struct HeartbeatSender {
36 id: Id,
37 role: Role,
38 sender: mpsc::Sender<HeartbeatRequest>,
39}
40
41impl HeartbeatSender {
42 #[inline]
43 fn new(id: Id, role: Role, sender: mpsc::Sender<HeartbeatRequest>) -> Self {
44 Self { id, role, sender }
45 }
46
47 #[inline]
48 pub fn id(&self) -> Id {
49 self.id
50 }
51
52 #[inline]
53 pub async fn send(&self, mut req: HeartbeatRequest) -> Result<()> {
54 req.set_header(
55 self.id,
56 self.role,
57 TracingContext::from_current_span().to_w3c(),
58 );
59 self.sender.send(req).await.map_err(|e| {
60 error::SendHeartbeatSnafu {
61 err_msg: e.to_string(),
62 }
63 .build()
64 })
65 }
66}
67
68#[derive(Debug)]
69pub struct HeartbeatStream {
70 id: Id,
71 stream: Streaming<HeartbeatResponse>,
72}
73
74impl HeartbeatStream {
75 #[inline]
76 fn new(id: Id, stream: Streaming<HeartbeatResponse>) -> Self {
77 Self { id, stream }
78 }
79
80 #[inline]
81 pub fn id(&self) -> Id {
82 self.id
83 }
84
85 #[inline]
87 pub async fn message(&mut self) -> Result<Option<HeartbeatResponse>> {
88 let res = self.stream.message().await.map_err(error::Error::from);
89 if let Ok(Some(heartbeat)) = &res {
90 util::check_response_header(heartbeat.header.as_ref())
91 .context(InvalidResponseHeaderSnafu)?;
92 }
93 res
94 }
95}
96
97#[derive(Clone, Debug)]
98pub struct Client {
99 inner: Arc<RwLock<Inner>>,
100}
101
102impl Client {
103 pub fn new(id: Id, role: Role, channel_manager: ChannelManager, max_retry: usize) -> Self {
104 let inner = Arc::new(RwLock::new(Inner::new(
105 id,
106 role,
107 channel_manager,
108 max_retry,
109 )));
110 Self { inner }
111 }
112
113 pub async fn start<U, A>(&mut self, urls: A) -> Result<()>
114 where
115 U: AsRef<str>,
116 A: AsRef<[U]>,
117 {
118 let mut inner = self.inner.write().await;
119 inner.start(urls)
120 }
121
122 pub(crate) async fn start_with(&self, leader_provider: LeaderProviderRef) -> Result<()> {
124 let mut inner = self.inner.write().await;
125 inner.start_with(leader_provider)
126 }
127
128 pub async fn ask_leader(&mut self) -> Result<String> {
129 let inner = self.inner.read().await;
130 inner.ask_leader().await
131 }
132
133 pub async fn heartbeat(&mut self) -> Result<(HeartbeatSender, HeartbeatStream)> {
134 let inner = self.inner.read().await;
135 inner.ask_leader().await?;
136 inner.heartbeat().await
137 }
138}
139
140#[derive(Debug)]
141struct Inner {
142 id: Id,
143 role: Role,
144 channel_manager: ChannelManager,
145 leader_provider: Option<LeaderProviderRef>,
146 max_retry: usize,
147}
148
149impl Inner {
150 fn new(id: Id, role: Role, channel_manager: ChannelManager, max_retry: usize) -> Self {
151 Self {
152 id,
153 role,
154 channel_manager,
155 leader_provider: None,
156 max_retry,
157 }
158 }
159
160 fn start_with(&mut self, leader_provider: LeaderProviderRef) -> Result<()> {
161 ensure!(
162 !self.is_started(),
163 error::IllegalGrpcClientStateSnafu {
164 err_msg: "Heartbeat client already started"
165 }
166 );
167 self.leader_provider = Some(leader_provider);
168 Ok(())
169 }
170
171 fn start<U, A>(&mut self, urls: A) -> Result<()>
172 where
173 U: AsRef<str>,
174 A: AsRef<[U]>,
175 {
176 let peers = urls
177 .as_ref()
178 .iter()
179 .map(|url| url.as_ref().to_string())
180 .collect::<Vec<_>>();
181 let ask_leader = AskLeader::new(
182 self.id,
183 self.role,
184 peers,
185 self.channel_manager.clone(),
186 self.max_retry,
187 );
188 self.start_with(Arc::new(ask_leader))
189 }
190
191 async fn ask_leader(&self) -> Result<String> {
192 let Some(leader_provider) = self.leader_provider.as_ref() else {
193 return error::IllegalGrpcClientStateSnafu {
194 err_msg: "not started",
195 }
196 .fail();
197 };
198 leader_provider.ask_leader().await
199 }
200
201 async fn heartbeat(&self) -> Result<(HeartbeatSender, HeartbeatStream)> {
202 ensure!(
203 self.is_started(),
204 error::IllegalGrpcClientStateSnafu {
205 err_msg: "Heartbeat client not start"
206 }
207 );
208
209 let leader_addr = self
210 .leader_provider
211 .as_ref()
212 .unwrap()
213 .leader()
214 .context(error::NoLeaderSnafu)?;
215 let mut leader = self.make_client(&leader_addr)?;
216
217 let (sender, receiver) = mpsc::channel::<HeartbeatRequest>(128);
218
219 let header = RequestHeader::new(
220 self.id,
221 self.role,
222 TracingContext::from_current_span().to_w3c(),
223 );
224 let handshake = HeartbeatRequest {
225 header: Some(header),
226 ..Default::default()
227 };
228 sender.send(handshake).await.map_err(|e| {
229 error::SendHeartbeatSnafu {
230 err_msg: e.to_string(),
231 }
232 .build()
233 })?;
234 let receiver = ReceiverStream::new(receiver);
235
236 let mut stream = leader
237 .heartbeat(receiver)
238 .await
239 .map_err(error::Error::from)?
240 .into_inner();
241
242 let res = stream
243 .message()
244 .await
245 .map_err(error::Error::from)?
246 .context(error::CreateHeartbeatStreamSnafu)?;
247
248 info!(
249 "Success to create heartbeat stream to server: {}, response: {:#?}",
250 leader_addr, res
251 );
252
253 Ok((
254 HeartbeatSender::new(self.id, self.role, sender),
255 HeartbeatStream::new(self.id, stream),
256 ))
257 }
258
259 fn make_client(&self, addr: impl AsRef<str>) -> Result<HeartbeatClient<Channel>> {
260 let channel = self
261 .channel_manager
262 .get(addr)
263 .context(error::CreateChannelSnafu)?;
264
265 Ok(HeartbeatClient::new(channel)
266 .accept_compressed(CompressionEncoding::Zstd)
267 .accept_compressed(CompressionEncoding::Gzip)
268 .send_compressed(CompressionEncoding::Zstd))
269 }
270
271 #[inline]
272 pub(crate) fn is_started(&self) -> bool {
273 self.leader_provider.is_some()
274 }
275}
276
#[cfg(test)]
mod test {
    use super::*;

    /// Starting an already-started client must fail with `IllegalGrpcClientState`.
    #[tokio::test]
    async fn test_already_start() {
        let mut client = Client::new(0, Role::Datanode, ChannelManager::default(), 3);
        client
            .start(&["127.0.0.1:1000", "127.0.0.1:1001"])
            .await
            .unwrap();
        let res = client.start(&["127.0.0.1:1002"]).await;
        assert!(res.is_err());
        assert!(matches!(
            res.err(),
            Some(error::Error::IllegalGrpcClientState { .. })
        ));
    }

    /// Every request sent through `HeartbeatSender` carries the sender's member id.
    #[tokio::test]
    async fn test_heartbeat_stream() {
        const SENT: usize = 10;

        let (sender, mut receiver) = mpsc::channel::<HeartbeatRequest>(100);
        let sender = HeartbeatSender::new(8, Role::Datanode, sender);
        let handle = tokio::spawn(async move {
            for _ in 0..SENT {
                sender.send(HeartbeatRequest::default()).await.unwrap();
            }
        });

        let mut received = 0;
        while let Some(req) = receiver.recv().await {
            let header = req.header.unwrap();
            assert_eq!(8, header.member_id);
            received += 1;
        }

        // A panic in the producer would otherwise just close the channel early
        // and let the test pass. Awaiting the handle surfaces it.
        handle.await.unwrap();
        assert_eq!(SENT, received);
    }
}