1use std::sync::Arc;
16
17use api::v1::meta::heartbeat_client::HeartbeatClient;
18use api::v1::meta::{HeartbeatRequest, HeartbeatResponse, RequestHeader, Role};
19use common_grpc::channel_manager::ChannelManager;
20use common_meta::util;
21use common_telemetry::info;
22use common_telemetry::tracing_context::TracingContext;
23use snafu::{ensure, OptionExt, ResultExt};
24use tokio::sync::{mpsc, RwLock};
25use tokio_stream::wrappers::ReceiverStream;
26use tonic::codec::CompressionEncoding;
27use tonic::transport::Channel;
28use tonic::Streaming;
29
30use crate::client::ask_leader::AskLeader;
31use crate::client::Id;
32use crate::error;
33use crate::error::{InvalidResponseHeaderSnafu, Result};
34
35pub struct HeartbeatSender {
36 id: Id,
37 role: Role,
38 sender: mpsc::Sender<HeartbeatRequest>,
39}
40
41impl HeartbeatSender {
42 #[inline]
43 fn new(id: Id, role: Role, sender: mpsc::Sender<HeartbeatRequest>) -> Self {
44 Self { id, role, sender }
45 }
46
47 #[inline]
48 pub fn id(&self) -> Id {
49 self.id
50 }
51
52 #[inline]
53 pub async fn send(&self, mut req: HeartbeatRequest) -> Result<()> {
54 req.set_header(
55 self.id,
56 self.role,
57 TracingContext::from_current_span().to_w3c(),
58 );
59 self.sender.send(req).await.map_err(|e| {
60 error::SendHeartbeatSnafu {
61 err_msg: e.to_string(),
62 }
63 .build()
64 })
65 }
66}
67
68#[derive(Debug)]
69pub struct HeartbeatStream {
70 id: Id,
71 stream: Streaming<HeartbeatResponse>,
72}
73
74impl HeartbeatStream {
75 #[inline]
76 fn new(id: Id, stream: Streaming<HeartbeatResponse>) -> Self {
77 Self { id, stream }
78 }
79
80 #[inline]
81 pub fn id(&self) -> Id {
82 self.id
83 }
84
85 #[inline]
87 pub async fn message(&mut self) -> Result<Option<HeartbeatResponse>> {
88 let res = self.stream.message().await.map_err(error::Error::from);
89 if let Ok(Some(heartbeat)) = &res {
90 util::check_response_header(heartbeat.header.as_ref())
91 .context(InvalidResponseHeaderSnafu)?;
92 }
93 res
94 }
95}
96
97#[derive(Clone, Debug)]
98pub struct Client {
99 inner: Arc<RwLock<Inner>>,
100}
101
102impl Client {
103 pub fn new(id: Id, role: Role, channel_manager: ChannelManager, max_retry: usize) -> Self {
104 let inner = Arc::new(RwLock::new(Inner::new(
105 id,
106 role,
107 channel_manager,
108 max_retry,
109 )));
110 Self { inner }
111 }
112
113 pub async fn start<U, A>(&mut self, urls: A) -> Result<()>
114 where
115 U: AsRef<str>,
116 A: AsRef<[U]>,
117 {
118 let mut inner = self.inner.write().await;
119 inner.start(urls).await
120 }
121
122 pub async fn ask_leader(&mut self) -> Result<String> {
123 let inner = self.inner.read().await;
124 inner.ask_leader().await
125 }
126
127 pub async fn heartbeat(&mut self) -> Result<(HeartbeatSender, HeartbeatStream)> {
128 let inner = self.inner.read().await;
129 inner.ask_leader().await?;
130 inner.heartbeat().await
131 }
132}
133
134#[derive(Debug)]
135struct Inner {
136 id: Id,
137 role: Role,
138 channel_manager: ChannelManager,
139 ask_leader: Option<AskLeader>,
140 max_retry: usize,
141}
142
143impl Inner {
144 fn new(id: Id, role: Role, channel_manager: ChannelManager, max_retry: usize) -> Self {
145 Self {
146 id,
147 role,
148 channel_manager,
149 ask_leader: None,
150 max_retry,
151 }
152 }
153
154 async fn start<U, A>(&mut self, urls: A) -> Result<()>
155 where
156 U: AsRef<str>,
157 A: AsRef<[U]>,
158 {
159 ensure!(
160 !self.is_started(),
161 error::IllegalGrpcClientStateSnafu {
162 err_msg: "Heartbeat client already started"
163 }
164 );
165
166 let peers = urls
167 .as_ref()
168 .iter()
169 .map(|url| url.as_ref().to_string())
170 .collect::<Vec<_>>();
171 self.ask_leader = Some(AskLeader::new(
172 self.id,
173 self.role,
174 peers,
175 self.channel_manager.clone(),
176 self.max_retry,
177 ));
178
179 Ok(())
180 }
181
182 async fn ask_leader(&self) -> Result<String> {
183 ensure!(
184 self.is_started(),
185 error::IllegalGrpcClientStateSnafu {
186 err_msg: "Heartbeat client not start"
187 }
188 );
189
190 self.ask_leader.as_ref().unwrap().ask_leader().await
191 }
192
193 async fn heartbeat(&self) -> Result<(HeartbeatSender, HeartbeatStream)> {
194 ensure!(
195 self.is_started(),
196 error::IllegalGrpcClientStateSnafu {
197 err_msg: "Heartbeat client not start"
198 }
199 );
200
201 let leader_addr = self
202 .ask_leader
203 .as_ref()
204 .unwrap()
205 .get_leader()
206 .context(error::NoLeaderSnafu)?;
207 let mut leader = self.make_client(&leader_addr)?;
208
209 let (sender, receiver) = mpsc::channel::<HeartbeatRequest>(128);
210
211 let header = RequestHeader::new(
212 self.id,
213 self.role,
214 TracingContext::from_current_span().to_w3c(),
215 );
216 let handshake = HeartbeatRequest {
217 header: Some(header),
218 ..Default::default()
219 };
220 sender.send(handshake).await.map_err(|e| {
221 error::SendHeartbeatSnafu {
222 err_msg: e.to_string(),
223 }
224 .build()
225 })?;
226 let receiver = ReceiverStream::new(receiver);
227
228 let mut stream = leader
229 .heartbeat(receiver)
230 .await
231 .map_err(error::Error::from)?
232 .into_inner();
233
234 let res = stream
235 .message()
236 .await
237 .map_err(error::Error::from)?
238 .context(error::CreateHeartbeatStreamSnafu)?;
239
240 info!(
241 "Success to create heartbeat stream to server: {}, response: {:#?}",
242 leader_addr, res
243 );
244
245 Ok((
246 HeartbeatSender::new(self.id, self.role, sender),
247 HeartbeatStream::new(self.id, stream),
248 ))
249 }
250
251 fn make_client(&self, addr: impl AsRef<str>) -> Result<HeartbeatClient<Channel>> {
252 let channel = self
253 .channel_manager
254 .get(addr)
255 .context(error::CreateChannelSnafu)?;
256
257 Ok(HeartbeatClient::new(channel)
258 .accept_compressed(CompressionEncoding::Zstd)
259 .accept_compressed(CompressionEncoding::Gzip)
260 .send_compressed(CompressionEncoding::Zstd))
261 }
262
263 #[inline]
264 pub(crate) fn is_started(&self) -> bool {
265 self.ask_leader.is_some()
266 }
267}
268
#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn test_already_start() {
        let mut client = Client::new(0, Role::Datanode, ChannelManager::default(), 3);
        client
            .start(&["127.0.0.1:1000", "127.0.0.1:1001"])
            .await
            .unwrap();
        let res = client.start(&["127.0.0.1:1002"]).await;
        assert!(res.is_err());
        assert!(matches!(
            res.err(),
            Some(error::Error::IllegalGrpcClientState { .. })
        ));
    }

    #[tokio::test]
    async fn test_heartbeat_stream() {
        const REQUEST_COUNT: usize = 10;

        let (sender, mut receiver) = mpsc::channel::<HeartbeatRequest>(100);
        let sender = HeartbeatSender::new(8, Role::Datanode, sender);
        let handle = tokio::spawn(async move {
            for _ in 0..REQUEST_COUNT {
                sender.send(HeartbeatRequest::default()).await.unwrap();
            }
        });

        let mut received = 0;
        while let Some(req) = receiver.recv().await {
            let header = req.header.unwrap();
            assert_eq!(8, header.member_id);
            received += 1;
        }

        // If the sending task panicked, its sender was dropped and the loop above
        // ended early; awaiting the handle surfaces that panic instead of hiding it.
        handle.await.unwrap();
        assert_eq!(REQUEST_COUNT, received);
    }
}