servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::fmt::Display;
18use std::net::SocketAddr;
19use std::sync::Mutex as StdMutex;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use auth::UserProviderRef;
24use axum::extract::{DefaultBodyLimit, Request};
25use axum::http::StatusCode as HttpStatusCode;
26use axum::response::{IntoResponse, Response};
27use axum::routing::Route;
28use axum::serve::ListenerExt;
29use axum::{Router, middleware, routing};
30use common_base::Plugins;
31use common_base::readable_size::ReadableSize;
32use common_recordbatch::RecordBatch;
33use common_telemetry::{debug, error, info};
34use common_time::Timestamp;
35use common_time::timestamp::TimeUnit;
36use datatypes::data_type::DataType;
37use datatypes::schema::SchemaRef;
38use event::{LogState, LogValidatorRef};
39use futures::FutureExt;
40use http::{HeaderValue, Method};
41use prost::DecodeError;
42use serde::{Deserialize, Serialize};
43use serde_json::Value;
44use snafu::{ResultExt, ensure};
45use tokio::sync::Mutex;
46use tokio::sync::oneshot::{self, Sender};
47use tonic::codegen::Service;
48use tower::{Layer, ServiceBuilder};
49use tower_http::compression::CompressionLayer;
50use tower_http::cors::{AllowOrigin, Any, CorsLayer};
51use tower_http::decompression::RequestDecompressionLayer;
52use tower_http::trace::TraceLayer;
53
54use self::authorize::AuthState;
55use self::result::table_result::TableResponse;
56use crate::configurator::HttpConfiguratorRef;
57use crate::elasticsearch;
58use crate::error::{
59    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu,
60    OtherSnafu, Result,
61};
62use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
63use crate::http::otlp::OtlpState;
64use crate::http::prom_store::PromStoreState;
65use crate::http::prometheus::{
66    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
67    range_query, series_query,
68};
69use crate::http::result::arrow_result::ArrowResponse;
70use crate::http::result::csv_result::CsvResponse;
71use crate::http::result::error_result::ErrorResponse;
72use crate::http::result::greptime_result_v1::GreptimedbV1Response;
73use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
74use crate::http::result::json_result::JsonResponse;
75use crate::http::result::null_result::NullResponse;
76use crate::interceptor::LogIngestInterceptorRef;
77use crate::metrics::http_metrics_layer;
78use crate::metrics_handler::MetricsHandler;
79use crate::prometheus_handler::PrometheusHandlerRef;
80use crate::query_handler::sql::ServerSqlQueryHandlerRef;
81use crate::query_handler::{
82    InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
83    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
84    PromStoreProtocolHandlerRef,
85};
86use crate::request_memory_limiter::ServerMemoryLimiter;
87use crate::server::Server;
88
89pub mod authorize;
90#[cfg(feature = "dashboard")]
91mod dashboard;
92pub mod dyn_log;
93pub mod dyn_trace;
94pub mod event;
95pub mod extractor;
96pub mod handler;
97pub mod header;
98pub mod influxdb;
99pub mod jaeger;
100pub mod logs;
101pub mod loki;
102pub mod mem_prof;
103mod memory_limit;
104pub mod opentsdb;
105pub mod otlp;
106pub mod pprof;
107pub mod prom_store;
108pub mod prometheus;
109pub mod result;
110mod timeout;
111pub mod utils;
112
113use result::HttpOutputWriter;
114pub(crate) use timeout::DynamicTimeoutLayer;
115
116mod hints;
117mod read_preference;
118#[cfg(any(test, feature = "testing"))]
119pub mod test_helpers;
120
/// Version segment of every versioned HTTP endpoint; the builder nests routes under
/// `/{HTTP_API_VERSION}/...`.
pub const HTTP_API_VERSION: &str = "v1";
/// Path prefix of versioned endpoints, i.e. `/` + [`HTTP_API_VERSION`] + `/`.
pub const HTTP_API_PREFIX: &str = "/v1/";
123/// Default http body limit (64M).
124const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);
125
126/// Authorization header
127pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";
128
129// TODO(fys): This is a temporary workaround, it will be improved later
130pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
131
132#[derive(Default)]
133pub struct HttpServer {
134    router: StdMutex<Router>,
135    shutdown_tx: Mutex<Option<Sender<()>>>,
136    user_provider: Option<UserProviderRef>,
137    memory_limiter: ServerMemoryLimiter,
138
139    // plugins
140    plugins: Plugins,
141
142    // server configs
143    options: HttpOptions,
144    bind_addr: Option<SocketAddr>,
145}
146
147#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
148#[serde(default)]
149pub struct HttpOptions {
150    pub addr: String,
151
152    #[serde(with = "humantime_serde")]
153    pub timeout: Duration,
154
155    #[serde(skip)]
156    pub disable_dashboard: bool,
157
158    pub body_limit: ReadableSize,
159
160    /// Validation mode while decoding Prometheus remote write requests.
161    pub prom_validation_mode: PromValidationMode,
162
163    pub cors_allowed_origins: Vec<String>,
164
165    pub enable_cors: bool,
166}
167
168#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
169#[serde(rename_all = "snake_case")]
170pub enum PromValidationMode {
171    /// Force UTF8 validation
172    Strict,
173    /// Allow lossy UTF8 strings
174    Lossy,
175    /// Do not validate UTF8 strings.
176    Unchecked,
177}
178
179impl PromValidationMode {
180    /// Decodes provided bytes to [String] with optional UTF-8 validation.
181    pub fn decode_string(&self, bytes: &[u8]) -> std::result::Result<String, DecodeError> {
182        let result = match self {
183            PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
184                Ok(s) => s,
185                Err(e) => {
186                    debug!("Invalid UTF-8 string value: {:?}, error: {:?}", bytes, e);
187                    return Err(DecodeError::new("invalid utf-8"));
188                }
189            },
190            PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
191            PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
192        };
193        Ok(result)
194    }
195}
196
197impl Default for HttpOptions {
198    fn default() -> Self {
199        Self {
200            addr: "127.0.0.1:4000".to_string(),
201            timeout: Duration::from_secs(0),
202            disable_dashboard: false,
203            body_limit: DEFAULT_BODY_LIMIT,
204            cors_allowed_origins: Vec::new(),
205            enable_cors: true,
206            prom_validation_mode: PromValidationMode::Strict,
207        }
208    }
209}
210
211#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
212pub struct ColumnSchema {
213    name: String,
214    data_type: String,
215}
216
217impl ColumnSchema {
218    pub fn new(name: String, data_type: String) -> ColumnSchema {
219        ColumnSchema { name, data_type }
220    }
221}
222
223#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
224pub struct OutputSchema {
225    column_schemas: Vec<ColumnSchema>,
226}
227
228impl OutputSchema {
229    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
230        OutputSchema {
231            column_schemas: columns,
232        }
233    }
234}
235
236impl From<SchemaRef> for OutputSchema {
237    fn from(schema: SchemaRef) -> OutputSchema {
238        OutputSchema {
239            column_schemas: schema
240                .column_schemas()
241                .iter()
242                .map(|cs| ColumnSchema {
243                    name: cs.name.clone(),
244                    data_type: cs.data_type.name(),
245                })
246                .collect(),
247        }
248    }
249}
250
251#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
252pub struct HttpRecordsOutput {
253    schema: OutputSchema,
254    rows: Vec<Vec<Value>>,
255    // total_rows is equal to rows.len() in most cases,
256    // the Dashboard query result may be truncated, so we need to return the total_rows.
257    #[serde(default)]
258    total_rows: usize,
259
260    // plan level execution metrics
261    #[serde(skip_serializing_if = "HashMap::is_empty")]
262    #[serde(default)]
263    metrics: HashMap<String, Value>,
264}
265
266impl HttpRecordsOutput {
267    pub fn num_rows(&self) -> usize {
268        self.rows.len()
269    }
270
271    pub fn num_cols(&self) -> usize {
272        self.schema.column_schemas.len()
273    }
274
275    pub fn schema(&self) -> &OutputSchema {
276        &self.schema
277    }
278
279    pub fn rows(&self) -> &Vec<Vec<Value>> {
280        &self.rows
281    }
282}
283
284impl HttpRecordsOutput {
285    pub fn try_new(
286        schema: SchemaRef,
287        recordbatches: Vec<RecordBatch>,
288    ) -> std::result::Result<HttpRecordsOutput, Error> {
289        if recordbatches.is_empty() {
290            Ok(HttpRecordsOutput {
291                schema: OutputSchema::from(schema),
292                rows: vec![],
293                total_rows: 0,
294                metrics: Default::default(),
295            })
296        } else {
297            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
298            let mut rows = Vec::with_capacity(num_rows);
299
300            for recordbatch in recordbatches {
301                let mut writer = HttpOutputWriter::new(schema.num_columns(), None);
302                writer.write(recordbatch, &mut rows)?;
303            }
304
305            Ok(HttpRecordsOutput {
306                schema: OutputSchema::from(schema),
307                total_rows: rows.len(),
308                rows,
309                metrics: Default::default(),
310            })
311        }
312    }
313}
314
315#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
316#[serde(rename_all = "lowercase")]
317pub enum GreptimeQueryOutput {
318    AffectedRows(usize),
319    Records(HttpRecordsOutput),
320}
321
/// Output format in which the results of SQL queries are presented.
///
/// [`ResponseFormat::GreptimedbV1`] is the default.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    /// CSV; the flags are `(with_names, with_types)`, set by the `csvwithnames` and
    /// `csvwithnamesandtypes` format names.
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Looks up a format by its exact, case-sensitive name; unknown names yield `None`.
    pub fn parse(s: &str) -> Option<Self> {
        let format = match s {
            "arrow" => Self::Arrow,
            "csv" => Self::Csv(false, false),
            "csvwithnames" => Self::Csv(true, false),
            "csvwithnamesandtypes" => Self::Csv(true, true),
            "table" => Self::Table,
            "greptimedb_v1" => Self::GreptimedbV1,
            "influxdb_v1" => Self::InfluxdbV1,
            "json" => Self::Json,
            "null" => Self::Null,
            _ => return None,
        };
        Some(format)
    }

    /// Short name of the format.
    ///
    /// Every CSV variant reports `"csv"`, so this is not an exact inverse of [`Self::parse`].
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
364
/// Time unit requested through the InfluxDB v1 `epoch` query-string parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Epoch {
    /// `ns`
    Nanosecond,
    /// `u` or `µ`
    Microsecond,
    /// `ms`
    Millisecond,
    /// `s`
    Second,
}
372
373impl Epoch {
374    pub fn parse(s: &str) -> Option<Epoch> {
375        // Both u and µ indicate microseconds.
376        // epoch = [ns,u,µ,ms,s],
377        // For details, see the Influxdb documents.
378        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
379        match s {
380            "ns" => Some(Epoch::Nanosecond),
381            "u" | "µ" => Some(Epoch::Microsecond),
382            "ms" => Some(Epoch::Millisecond),
383            "s" => Some(Epoch::Second),
384            _ => None, // just returns None for other cases
385        }
386    }
387
388    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
389        match self {
390            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
391            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
392            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
393            Epoch::Second => ts.convert_to(TimeUnit::Second),
394        }
395    }
396}
397
398impl Display for Epoch {
399    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400        match self {
401            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
402            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
403            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
404            Epoch::Second => write!(f, "Epoch::Second"),
405        }
406    }
407}
408
409#[derive(Serialize, Deserialize, Debug)]
410pub enum HttpResponse {
411    Arrow(ArrowResponse),
412    Csv(CsvResponse),
413    Table(TableResponse),
414    Error(ErrorResponse),
415    GreptimedbV1(GreptimedbV1Response),
416    InfluxdbV1(InfluxdbV1Response),
417    Json(JsonResponse),
418    Null(NullResponse),
419}
420
421impl HttpResponse {
422    pub fn with_execution_time(self, execution_time: u64) -> Self {
423        match self {
424            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
425            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
426            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
427            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
428            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
429            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
430            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
431            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
432        }
433    }
434
435    pub fn with_limit(self, limit: usize) -> Self {
436        match self {
437            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
438            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
439            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
440            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
441            _ => self,
442        }
443    }
444}
445
446pub fn process_with_limit(
447    mut outputs: Vec<GreptimeQueryOutput>,
448    limit: usize,
449) -> Vec<GreptimeQueryOutput> {
450    outputs
451        .drain(..)
452        .map(|data| match data {
453            GreptimeQueryOutput::Records(mut records) => {
454                if records.rows.len() > limit {
455                    records.rows.truncate(limit);
456                    records.total_rows = limit;
457                }
458                GreptimeQueryOutput::Records(records)
459            }
460            _ => data,
461        })
462        .collect()
463}
464
465impl IntoResponse for HttpResponse {
466    fn into_response(self) -> Response {
467        match self {
468            HttpResponse::Arrow(resp) => resp.into_response(),
469            HttpResponse::Csv(resp) => resp.into_response(),
470            HttpResponse::Table(resp) => resp.into_response(),
471            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
472            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
473            HttpResponse::Json(resp) => resp.into_response(),
474            HttpResponse::Null(resp) => resp.into_response(),
475            HttpResponse::Error(resp) => resp.into_response(),
476        }
477    }
478}
479
480impl From<ArrowResponse> for HttpResponse {
481    fn from(value: ArrowResponse) -> Self {
482        HttpResponse::Arrow(value)
483    }
484}
485
486impl From<CsvResponse> for HttpResponse {
487    fn from(value: CsvResponse) -> Self {
488        HttpResponse::Csv(value)
489    }
490}
491
492impl From<TableResponse> for HttpResponse {
493    fn from(value: TableResponse) -> Self {
494        HttpResponse::Table(value)
495    }
496}
497
498impl From<ErrorResponse> for HttpResponse {
499    fn from(value: ErrorResponse) -> Self {
500        HttpResponse::Error(value)
501    }
502}
503
504impl From<GreptimedbV1Response> for HttpResponse {
505    fn from(value: GreptimedbV1Response) -> Self {
506        HttpResponse::GreptimedbV1(value)
507    }
508}
509
510impl From<InfluxdbV1Response> for HttpResponse {
511    fn from(value: InfluxdbV1Response) -> Self {
512        HttpResponse::InfluxdbV1(value)
513    }
514}
515
516impl From<JsonResponse> for HttpResponse {
517    fn from(value: JsonResponse) -> Self {
518        HttpResponse::Json(value)
519    }
520}
521
522impl From<NullResponse> for HttpResponse {
523    fn from(value: NullResponse) -> Self {
524        HttpResponse::Null(value)
525    }
526}
527
528#[derive(Clone)]
529pub struct ApiState {
530    pub sql_handler: ServerSqlQueryHandlerRef,
531}
532
/// State of the config route registered by `HttpServerBuilder::with_greptime_config_options`.
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    /// Configuration options text handed to the builder.
    pub greptime_config_options: String,
}
537
538pub struct HttpServerBuilder {
539    options: HttpOptions,
540    plugins: Plugins,
541    user_provider: Option<UserProviderRef>,
542    router: Router,
543    memory_limiter: ServerMemoryLimiter,
544}
545
546impl HttpServerBuilder {
547    pub fn new(options: HttpOptions) -> Self {
548        Self {
549            options,
550            plugins: Plugins::default(),
551            user_provider: None,
552            router: Router::new(),
553            memory_limiter: ServerMemoryLimiter::default(),
554        }
555    }
556
557    /// Set a global memory limiter for all server protocols.
558    pub fn with_memory_limiter(mut self, limiter: ServerMemoryLimiter) -> Self {
559        self.memory_limiter = limiter;
560        self
561    }
562
563    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
564        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
565
566        Self {
567            router: self
568                .router
569                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
570            ..self
571        }
572    }
573
574    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
575        let logs_router = HttpServer::route_logs(logs_handler);
576
577        Self {
578            router: self
579                .router
580                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
581            ..self
582        }
583    }
584
585    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
586        Self {
587            router: self.router.nest(
588                &format!("/{HTTP_API_VERSION}/opentsdb"),
589                HttpServer::route_opentsdb(handler),
590            ),
591            ..self
592        }
593    }
594
595    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
596        Self {
597            router: self.router.nest(
598                &format!("/{HTTP_API_VERSION}/influxdb"),
599                HttpServer::route_influxdb(handler),
600            ),
601            ..self
602        }
603    }
604
605    pub fn with_prom_handler(
606        self,
607        handler: PromStoreProtocolHandlerRef,
608        pipeline_handler: Option<PipelineHandlerRef>,
609        prom_store_with_metric_engine: bool,
610        prom_validation_mode: PromValidationMode,
611    ) -> Self {
612        let state = PromStoreState {
613            prom_store_handler: handler,
614            pipeline_handler,
615            prom_store_with_metric_engine,
616            prom_validation_mode,
617        };
618
619        Self {
620            router: self.router.nest(
621                &format!("/{HTTP_API_VERSION}/prometheus"),
622                HttpServer::route_prom(state),
623            ),
624            ..self
625        }
626    }
627
628    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
629        Self {
630            router: self.router.nest(
631                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
632                HttpServer::route_prometheus(handler),
633            ),
634            ..self
635        }
636    }
637
638    pub fn with_otlp_handler(
639        self,
640        handler: OpenTelemetryProtocolHandlerRef,
641        with_metric_engine: bool,
642    ) -> Self {
643        Self {
644            router: self.router.nest(
645                &format!("/{HTTP_API_VERSION}/otlp"),
646                HttpServer::route_otlp(handler, with_metric_engine),
647            ),
648            ..self
649        }
650    }
651
652    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
653        Self {
654            user_provider: Some(user_provider),
655            ..self
656        }
657    }
658
659    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
660        Self {
661            router: self.router.merge(HttpServer::route_metrics(handler)),
662            ..self
663        }
664    }
665
666    pub fn with_log_ingest_handler(
667        self,
668        handler: PipelineHandlerRef,
669        validator: Option<LogValidatorRef>,
670        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
671    ) -> Self {
672        let log_state = LogState {
673            log_handler: handler,
674            log_validator: validator,
675            ingest_interceptor,
676        };
677
678        let router = self.router.nest(
679            &format!("/{HTTP_API_VERSION}"),
680            HttpServer::route_pipelines(log_state.clone()),
681        );
682        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
683        let router = router.nest(
684            &format!("/{HTTP_API_VERSION}/events"),
685            #[allow(deprecated)]
686            HttpServer::route_log_deprecated(log_state.clone()),
687        );
688
689        let router = router.nest(
690            &format!("/{HTTP_API_VERSION}/loki"),
691            HttpServer::route_loki(log_state.clone()),
692        );
693
694        let router = router.nest(
695            &format!("/{HTTP_API_VERSION}/elasticsearch"),
696            HttpServer::route_elasticsearch(log_state.clone()),
697        );
698
699        let router = router.nest(
700            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
701            Router::new()
702                .route("/", routing::get(elasticsearch::handle_get_version))
703                .with_state(log_state),
704        );
705
706        Self { router, ..self }
707    }
708
709    pub fn with_plugins(self, plugins: Plugins) -> Self {
710        Self { plugins, ..self }
711    }
712
713    pub fn with_greptime_config_options(self, opts: String) -> Self {
714        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
715            greptime_config_options: opts,
716        });
717
718        Self {
719            router: self.router.merge(config_router),
720            ..self
721        }
722    }
723
724    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
725        Self {
726            router: self.router.nest(
727                &format!("/{HTTP_API_VERSION}/jaeger"),
728                HttpServer::route_jaeger(handler),
729            ),
730            ..self
731        }
732    }
733
734    pub fn with_extra_router(self, router: Router) -> Self {
735        Self {
736            router: self.router.merge(router),
737            ..self
738        }
739    }
740
741    pub fn add_layer<L>(self, layer: L) -> Self
742    where
743        L: Layer<Route> + Clone + Send + Sync + 'static,
744        L::Service: Service<Request> + Clone + Send + Sync + 'static,
745        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
746        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
747        <L::Service as Service<Request>>::Future: Send + 'static,
748    {
749        Self {
750            router: self.router.layer(layer),
751            ..self
752        }
753    }
754
755    pub fn build(self) -> HttpServer {
756        HttpServer {
757            options: self.options,
758            user_provider: self.user_provider,
759            shutdown_tx: Mutex::new(None),
760            plugins: self.plugins,
761            router: StdMutex::new(self.router),
762            bind_addr: None,
763            memory_limiter: self.memory_limiter,
764        }
765    }
766}
767
768impl HttpServer {
769    /// Gets the router and adds necessary root routes (health, status, dashboard).
770    pub fn make_app(&self) -> Router {
771        let mut router = {
772            let router = self.router.lock().unwrap();
773            router.clone()
774        };
775
776        router = router
777            .route(
778                "/health",
779                routing::get(handler::health).post(handler::health),
780            )
781            .route(
782                &format!("/{HTTP_API_VERSION}/health"),
783                routing::get(handler::health).post(handler::health),
784            )
785            .route(
786                "/ready",
787                routing::get(handler::health).post(handler::health),
788            );
789
790        router = router.route("/status", routing::get(handler::status));
791
792        #[cfg(feature = "dashboard")]
793        {
794            if !self.options.disable_dashboard {
795                info!("Enable dashboard service at '/dashboard'");
796                // redirect /dashboard to /dashboard/
797                router = router.route(
798                    "/dashboard",
799                    routing::get(|uri: axum::http::uri::Uri| async move {
800                        let path = uri.path();
801                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
802
803                        let new_uri = format!("{}/{}", path, query);
804                        axum::response::Redirect::permanent(&new_uri)
805                    }),
806                );
807
808                // "/dashboard" and "/dashboard/" are two different paths in Axum.
809                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
810                // So we explicitly route "/dashboard/" here.
811                router = router
812                    .route(
813                        "/dashboard/",
814                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
815                    )
816                    .route(
817                        "/dashboard/{*x}",
818                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
819                    );
820            }
821        }
822
823        // Add a layer to collect HTTP metrics for axum.
824        router = router.route_layer(middleware::from_fn(http_metrics_layer));
825
826        router
827    }
828
829    /// Attaches middlewares and debug routes to the router.
830    /// Callers should call this method after [HttpServer::make_app()].
831    pub fn build(&self, router: Router) -> Result<Router> {
832        let timeout_layer = if self.options.timeout != Duration::default() {
833            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
834        } else {
835            info!("HTTP server timeout is disabled");
836            None
837        };
838        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
839            Some(
840                ServiceBuilder::new()
841                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
842            )
843        } else {
844            info!("HTTP server body limit is disabled");
845            None
846        };
847        let cors_layer = if self.options.enable_cors {
848            Some(
849                CorsLayer::new()
850                    .allow_methods([
851                        Method::GET,
852                        Method::POST,
853                        Method::PUT,
854                        Method::DELETE,
855                        Method::HEAD,
856                    ])
857                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
858                        AllowOrigin::from(Any)
859                    } else {
860                        AllowOrigin::from(
861                            self.options
862                                .cors_allowed_origins
863                                .iter()
864                                .map(|s| {
865                                    HeaderValue::from_str(s.as_str())
866                                        .context(InvalidHeaderValueSnafu)
867                                })
868                                .collect::<Result<Vec<HeaderValue>>>()?,
869                        )
870                    })
871                    .allow_headers(Any),
872            )
873        } else {
874            info!("HTTP server cross-origin is disabled");
875            None
876        };
877
878        Ok(router
879            // middlewares
880            .layer(
881                ServiceBuilder::new()
882                    // disable on failure tracing. because printing out isn't very helpful,
883                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
884                    .layer(TraceLayer::new_for_http().on_failure(()))
885                    .option_layer(cors_layer)
886                    .option_layer(timeout_layer)
887                    .option_layer(body_limit_layer)
888                    // memory limit layer - must be before body is consumed
889                    .layer(middleware::from_fn_with_state(
890                        self.memory_limiter.clone(),
891                        memory_limit::memory_limit_middleware,
892                    ))
893                    // auth layer
894                    .layer(middleware::from_fn_with_state(
895                        AuthState::new(self.user_provider.clone()),
896                        authorize::check_http_auth,
897                    ))
898                    .layer(middleware::from_fn(hints::extract_hints))
899                    .layer(middleware::from_fn(
900                        read_preference::extract_read_preference,
901                    )),
902            )
903            // Handlers for debug, we don't expect a timeout.
904            .nest(
905                "/debug",
906                Router::new()
907                    // handler for changing log level dynamically
908                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
909                    .route("/enable_trace", routing::post(dyn_trace::dyn_trace_handler))
910                    .nest(
911                        "/prof",
912                        Router::new()
913                            .route("/cpu", routing::post(pprof::pprof_handler))
914                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
915                            .route(
916                                "/mem/activate",
917                                routing::post(mem_prof::activate_heap_prof_handler),
918                            )
919                            .route(
920                                "/mem/deactivate",
921                                routing::post(mem_prof::deactivate_heap_prof_handler),
922                            )
923                            .route(
924                                "/mem/status",
925                                routing::get(mem_prof::heap_prof_status_handler),
926                            ) // jemalloc gdump flag status and toggle
927                            .route(
928                                "/mem/gdump",
929                                routing::get(mem_prof::gdump_status_handler)
930                                    .post(mem_prof::gdump_toggle_handler),
931                            ),
932                    ),
933            ))
934    }
935
936    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
937        Router::new()
938            .route("/metrics", routing::get(handler::metrics))
939            .with_state(metrics_handler)
940    }
941
942    fn route_loki<S>(log_state: LogState) -> Router<S> {
943        Router::new()
944            .route("/api/v1/push", routing::post(loki::loki_ingest))
945            .layer(
946                ServiceBuilder::new()
947                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
948            )
949            .with_state(log_state)
950    }
951
952    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
953        Router::new()
954            // Return fake responsefor HEAD '/' request.
955            .route(
956                "/",
957                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
958            )
959            // Return fake response for Elasticsearch version request.
960            .route("/", routing::get(elasticsearch::handle_get_version))
961            // Return fake response for Elasticsearch license request.
962            .route("/_license", routing::get(elasticsearch::handle_get_license))
963            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
964            .route(
965                "/{index}/_bulk",
966                routing::post(elasticsearch::handle_bulk_api_with_index),
967            )
968            // Return fake response for Elasticsearch ilm request.
969            .route(
970                "/_ilm/policy/{*path}",
971                routing::any((
972                    HttpStatusCode::OK,
973                    elasticsearch::elasticsearch_headers(),
974                    axum::Json(serde_json::json!({})),
975                )),
976            )
977            // Return fake response for Elasticsearch index template request.
978            .route(
979                "/_index_template/{*path}",
980                routing::any((
981                    HttpStatusCode::OK,
982                    elasticsearch::elasticsearch_headers(),
983                    axum::Json(serde_json::json!({})),
984                )),
985            )
986            // Return fake response for Elasticsearch ingest pipeline request.
987            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
988            .route(
989                "/_ingest/{*path}",
990                routing::any((
991                    HttpStatusCode::OK,
992                    elasticsearch::elasticsearch_headers(),
993                    axum::Json(serde_json::json!({})),
994                )),
995            )
996            // Return fake response for Elasticsearch nodes discovery request.
997            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
998            .route(
999                "/_nodes/{*path}",
1000                routing::any((
1001                    HttpStatusCode::OK,
1002                    elasticsearch::elasticsearch_headers(),
1003                    axum::Json(serde_json::json!({})),
1004                )),
1005            )
1006            // Return fake response for Logstash APIs requests.
1007            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
1008            .route(
1009                "/logstash/{*path}",
1010                routing::any((
1011                    HttpStatusCode::OK,
1012                    elasticsearch::elasticsearch_headers(),
1013                    axum::Json(serde_json::json!({})),
1014                )),
1015            )
1016            .route(
1017                "/_logstash/{*path}",
1018                routing::any((
1019                    HttpStatusCode::OK,
1020                    elasticsearch::elasticsearch_headers(),
1021                    axum::Json(serde_json::json!({})),
1022                )),
1023            )
1024            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1025            .with_state(log_state)
1026    }
1027
1028    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1029    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1030        Router::new()
1031            .route("/logs", routing::post(event::log_ingester))
1032            .route(
1033                "/pipelines/{pipeline_name}",
1034                routing::get(event::query_pipeline),
1035            )
1036            .route(
1037                "/pipelines/{pipeline_name}",
1038                routing::post(event::add_pipeline),
1039            )
1040            .route(
1041                "/pipelines/{pipeline_name}",
1042                routing::delete(event::delete_pipeline),
1043            )
1044            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1045            .layer(
1046                ServiceBuilder::new()
1047                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1048            )
1049            .with_state(log_state)
1050    }
1051
1052    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1053        Router::new()
1054            .route("/ingest", routing::post(event::log_ingester))
1055            .route(
1056                "/pipelines/{pipeline_name}",
1057                routing::get(event::query_pipeline),
1058            )
1059            .route(
1060                "/pipelines/{pipeline_name}/ddl",
1061                routing::get(event::query_pipeline_ddl),
1062            )
1063            .route(
1064                "/pipelines/{pipeline_name}",
1065                routing::post(event::add_pipeline),
1066            )
1067            .route(
1068                "/pipelines/{pipeline_name}",
1069                routing::delete(event::delete_pipeline),
1070            )
1071            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1072            .layer(
1073                ServiceBuilder::new()
1074                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1075            )
1076            .with_state(log_state)
1077    }
1078
1079    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1080        Router::new()
1081            .route("/sql", routing::get(handler::sql).post(handler::sql))
1082            .route(
1083                "/sql/parse",
1084                routing::get(handler::sql_parse).post(handler::sql_parse),
1085            )
1086            .route(
1087                "/sql/format",
1088                routing::get(handler::sql_format).post(handler::sql_format),
1089            )
1090            .route(
1091                "/promql",
1092                routing::get(handler::promql).post(handler::promql),
1093            )
1094            .with_state(api_state)
1095    }
1096
1097    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1098        Router::new()
1099            .route("/logs", routing::get(logs::logs).post(logs::logs))
1100            .with_state(log_handler)
1101    }
1102
1103    /// Route Prometheus [HTTP API].
1104    ///
1105    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1106    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1107        Router::new()
1108            .route(
1109                "/format_query",
1110                routing::post(format_query).get(format_query),
1111            )
1112            .route("/status/buildinfo", routing::get(build_info_query))
1113            .route("/query", routing::post(instant_query).get(instant_query))
1114            .route("/query_range", routing::post(range_query).get(range_query))
1115            .route("/labels", routing::post(labels_query).get(labels_query))
1116            .route("/series", routing::post(series_query).get(series_query))
1117            .route("/parse_query", routing::post(parse_query).get(parse_query))
1118            .route(
1119                "/label/{label_name}/values",
1120                routing::get(label_values_query),
1121            )
1122            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1123            .with_state(prometheus_handler)
1124    }
1125
1126    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1127    /// called `prom_store`.
1128    ///
1129    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1130    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1131    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1132        Router::new()
1133            .route("/read", routing::post(prom_store::remote_read))
1134            .route("/write", routing::post(prom_store::remote_write))
1135            .with_state(state)
1136    }
1137
1138    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1139        Router::new()
1140            .route("/write", routing::post(influxdb_write_v1))
1141            .route("/api/v2/write", routing::post(influxdb_write_v2))
1142            .layer(
1143                ServiceBuilder::new()
1144                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1145            )
1146            .route("/ping", routing::get(influxdb_ping))
1147            .route("/health", routing::get(influxdb_health))
1148            .with_state(influxdb_handler)
1149    }
1150
1151    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1152        Router::new()
1153            .route("/api/put", routing::post(opentsdb::put))
1154            .with_state(opentsdb_handler)
1155    }
1156
1157    fn route_otlp<S>(
1158        otlp_handler: OpenTelemetryProtocolHandlerRef,
1159        with_metric_engine: bool,
1160    ) -> Router<S> {
1161        Router::new()
1162            .route("/v1/metrics", routing::post(otlp::metrics))
1163            .route("/v1/traces", routing::post(otlp::traces))
1164            .route("/v1/logs", routing::post(otlp::logs))
1165            .layer(
1166                ServiceBuilder::new()
1167                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1168            )
1169            .with_state(OtlpState {
1170                with_metric_engine,
1171                handler: otlp_handler,
1172            })
1173    }
1174
1175    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1176        Router::new()
1177            .route("/config", routing::get(handler::config))
1178            .with_state(state)
1179    }
1180
1181    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1182        Router::new()
1183            .route("/api/services", routing::get(jaeger::handle_get_services))
1184            .route(
1185                "/api/services/{service_name}/operations",
1186                routing::get(jaeger::handle_get_operations_by_service),
1187            )
1188            .route(
1189                "/api/operations",
1190                routing::get(jaeger::handle_get_operations),
1191            )
1192            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1193            .route(
1194                "/api/traces/{trace_id}",
1195                routing::get(jaeger::handle_get_trace),
1196            )
1197            .with_state(handler)
1198    }
1199}
1200
/// Name that `HttpServer` reports through `Server::name`.
pub const HTTP_SERVER: &str = "HTTP_SERVER";
1202
1203#[async_trait]
1204impl Server for HttpServer {
1205    async fn shutdown(&self) -> Result<()> {
1206        let mut shutdown_tx = self.shutdown_tx.lock().await;
1207        if let Some(tx) = shutdown_tx.take()
1208            && tx.send(()).is_err()
1209        {
1210            info!("Receiver dropped, the HTTP server has already exited");
1211        }
1212        info!("Shutdown HTTP server");
1213
1214        Ok(())
1215    }
1216
1217    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1218        let (tx, rx) = oneshot::channel();
1219        let serve = {
1220            let mut shutdown_tx = self.shutdown_tx.lock().await;
1221            ensure!(
1222                shutdown_tx.is_none(),
1223                AlreadyStartedSnafu { server: "HTTP" }
1224            );
1225
1226            let mut app = self.make_app();
1227            if let Some(configurator) = self.plugins.get::<HttpConfiguratorRef<()>>() {
1228                app = configurator
1229                    .configure_http(app, ())
1230                    .await
1231                    .context(OtherSnafu)?;
1232            }
1233            let app = self.build(app)?;
1234            let listener = tokio::net::TcpListener::bind(listening)
1235                .await
1236                .context(AddressBindSnafu { addr: listening })?
1237                .tap_io(|tcp_stream| {
1238                    if let Err(e) = tcp_stream.set_nodelay(true) {
1239                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1240                    }
1241                });
1242            let serve = axum::serve(listener, app.into_make_service());
1243
1244            // FIXME(yingwen): Support keepalive.
1245            // See:
1246            // - https://github.com/tokio-rs/axum/discussions/2939
1247            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1248            // let server = axum::Server::try_bind(&listening)
1249            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1250            //     .tcp_nodelay(true)
1251            //     // Enable TCP keepalive to close the dangling established connections.
1252            //     // It's configured to let the keepalive probes first send after the connection sits
1253            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1254            //     // So the connection will be closed after roughly 1 hour.
1255            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1256            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1257            //     .tcp_keepalive_retries(Some(6))
1258            //     .serve(app.into_make_service());
1259
1260            *shutdown_tx = Some(tx);
1261
1262            serve
1263        };
1264        let listening = serve.local_addr().context(InternalIoSnafu)?;
1265        info!("HTTP server is bound to {}", listening);
1266
1267        common_runtime::spawn_global(async move {
1268            if let Err(e) = serve
1269                .with_graceful_shutdown(rx.map(drop))
1270                .await
1271                .context(InternalIoSnafu)
1272            {
1273                error!(e; "Failed to shutdown http server");
1274            }
1275        });
1276
1277        self.bind_addr = Some(listening);
1278        Ok(())
1279    }
1280
1281    fn name(&self) -> &str {
1282        HTTP_SERVER
1283    }
1284
1285    fn bind_addr(&self) -> Option<SocketAddr> {
1286        self.bind_addr
1287    }
1288}
1289
1290#[cfg(test)]
1291mod test {
1292    use std::future::pending;
1293    use std::io::Cursor;
1294    use std::sync::Arc;
1295
1296    use arrow_ipc::reader::FileReader;
1297    use arrow_schema::DataType;
1298    use axum::handler::Handler;
1299    use axum::http::StatusCode;
1300    use axum::routing::get;
1301    use common_query::Output;
1302    use common_recordbatch::RecordBatches;
1303    use datafusion_expr::LogicalPlan;
1304    use datatypes::prelude::*;
1305    use datatypes::schema::{ColumnSchema, Schema};
1306    use datatypes::vectors::{StringVector, UInt32Vector};
1307    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1308    use query::parser::PromQuery;
1309    use query::query_engine::DescribeResult;
1310    use session::context::QueryContextRef;
1311    use sql::statements::statement::Statement;
1312    use tokio::sync::mpsc;
1313    use tokio::time::Instant;
1314
1315    use super::*;
1316    use crate::error::Error;
1317    use crate::http::test_helpers::TestClient;
1318    use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
1319
1320    struct DummyInstance {
1321        _tx: mpsc::Sender<(String, Vec<u8>)>,
1322    }
1323
1324    #[async_trait]
1325    impl SqlQueryHandler for DummyInstance {
1326        type Error = Error;
1327
1328        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1329            unimplemented!()
1330        }
1331
1332        async fn do_promql_query(
1333            &self,
1334            _: &PromQuery,
1335            _: QueryContextRef,
1336        ) -> Vec<std::result::Result<Output, Self::Error>> {
1337            unimplemented!()
1338        }
1339
1340        async fn do_exec_plan(
1341            &self,
1342            _stmt: Option<Statement>,
1343            _plan: LogicalPlan,
1344            _query_ctx: QueryContextRef,
1345        ) -> std::result::Result<Output, Self::Error> {
1346            unimplemented!()
1347        }
1348
1349        async fn do_describe(
1350            &self,
1351            _stmt: sql::statements::statement::Statement,
1352            _query_ctx: QueryContextRef,
1353        ) -> Result<Option<DescribeResult>> {
1354            unimplemented!()
1355        }
1356
1357        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1358            Ok(true)
1359        }
1360    }
1361
1362    fn timeout() -> DynamicTimeoutLayer {
1363        DynamicTimeoutLayer::new(Duration::from_millis(10))
1364    }
1365
1366    async fn forever() {
1367        pending().await
1368    }
1369
1370    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1371        make_test_app_custom(tx, HttpOptions::default())
1372    }
1373
1374    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1375        let instance = Arc::new(DummyInstance { _tx: tx });
1376        let sql_instance = ServerSqlQueryHandlerAdapter::arc(instance.clone());
1377        let server = HttpServerBuilder::new(options)
1378            .with_sql_handler(sql_instance)
1379            .build();
1380        server.build(server.make_app()).unwrap().route(
1381            "/test/timeout",
1382            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1383        )
1384    }
1385
1386    #[tokio::test]
1387    pub async fn test_cors() {
1388        // cors is on by default
1389        let (tx, _rx) = mpsc::channel(100);
1390        let app = make_test_app(tx);
1391        let client = TestClient::new(app).await;
1392
1393        let res = client.get("/health").send().await;
1394
1395        assert_eq!(res.status(), StatusCode::OK);
1396        assert_eq!(
1397            res.headers()
1398                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1399                .expect("expect cors header origin"),
1400            "*"
1401        );
1402
1403        let res = client.get("/v1/health").send().await;
1404
1405        assert_eq!(res.status(), StatusCode::OK);
1406        assert_eq!(
1407            res.headers()
1408                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1409                .expect("expect cors header origin"),
1410            "*"
1411        );
1412
1413        let res = client
1414            .options("/health")
1415            .header("Access-Control-Request-Headers", "x-greptime-auth")
1416            .header("Access-Control-Request-Method", "DELETE")
1417            .header("Origin", "https://example.com")
1418            .send()
1419            .await;
1420        assert_eq!(res.status(), StatusCode::OK);
1421        assert_eq!(
1422            res.headers()
1423                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1424                .expect("expect cors header origin"),
1425            "*"
1426        );
1427        assert_eq!(
1428            res.headers()
1429                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1430                .expect("expect cors header headers"),
1431            "*"
1432        );
1433        assert_eq!(
1434            res.headers()
1435                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1436                .expect("expect cors header methods"),
1437            "GET,POST,PUT,DELETE,HEAD"
1438        );
1439    }
1440
1441    #[tokio::test]
1442    pub async fn test_cors_custom_origins() {
1443        // cors is on by default
1444        let (tx, _rx) = mpsc::channel(100);
1445        let origin = "https://example.com";
1446
1447        let options = HttpOptions {
1448            cors_allowed_origins: vec![origin.to_string()],
1449            ..Default::default()
1450        };
1451
1452        let app = make_test_app_custom(tx, options);
1453        let client = TestClient::new(app).await;
1454
1455        let res = client.get("/health").header("Origin", origin).send().await;
1456
1457        assert_eq!(res.status(), StatusCode::OK);
1458        assert_eq!(
1459            res.headers()
1460                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1461                .expect("expect cors header origin"),
1462            origin
1463        );
1464
1465        let res = client
1466            .get("/health")
1467            .header("Origin", "https://notallowed.com")
1468            .send()
1469            .await;
1470
1471        assert_eq!(res.status(), StatusCode::OK);
1472        assert!(
1473            !res.headers()
1474                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1475        );
1476    }
1477
1478    #[tokio::test]
1479    pub async fn test_cors_disabled() {
1480        // cors is on by default
1481        let (tx, _rx) = mpsc::channel(100);
1482
1483        let options = HttpOptions {
1484            enable_cors: false,
1485            ..Default::default()
1486        };
1487
1488        let app = make_test_app_custom(tx, options);
1489        let client = TestClient::new(app).await;
1490
1491        let res = client.get("/health").send().await;
1492
1493        assert_eq!(res.status(), StatusCode::OK);
1494        assert!(
1495            !res.headers()
1496                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1497        );
1498    }
1499
1500    #[test]
1501    fn test_http_options_default() {
1502        let default = HttpOptions::default();
1503        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1504        assert_eq!(Duration::from_secs(0), default.timeout)
1505    }
1506
1507    #[tokio::test]
1508    async fn test_http_server_request_timeout() {
1509        common_telemetry::init_default_ut_logging();
1510
1511        let (tx, _rx) = mpsc::channel(100);
1512        let app = make_test_app(tx);
1513        let client = TestClient::new(app).await;
1514        let res = client.get("/test/timeout").send().await;
1515        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1516
1517        let now = Instant::now();
1518        let res = client
1519            .get("/test/timeout")
1520            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1521            .send()
1522            .await;
1523        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1524        let elapsed = now.elapsed();
1525        assert!(elapsed > Duration::from_millis(15));
1526
1527        tokio::time::timeout(
1528            Duration::from_millis(15),
1529            client
1530                .get("/test/timeout")
1531                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1532                .send(),
1533        )
1534        .await
1535        .unwrap_err();
1536
1537        tokio::time::timeout(
1538            Duration::from_millis(15),
1539            client
1540                .get("/test/timeout")
1541                .header(
1542                    GREPTIME_DB_HEADER_TIMEOUT,
1543                    humantime::format_duration(Duration::default()).to_string(),
1544                )
1545                .send(),
1546        )
1547        .await
1548        .unwrap_err();
1549    }
1550
1551    #[tokio::test]
1552    async fn test_schema_for_empty_response() {
1553        let column_schemas = vec![
1554            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1555            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1556        ];
1557        let schema = Arc::new(Schema::new(column_schemas));
1558
1559        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1560        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1561
1562        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1563        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1564            let json_output = &json_resp.output[0];
1565            if let GreptimeQueryOutput::Records(r) = json_output {
1566                assert_eq!(r.num_rows(), 0);
1567                assert_eq!(r.num_cols(), 2);
1568                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1569                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1570            } else {
1571                panic!("invalid output type");
1572            }
1573        } else {
1574            panic!("invalid format")
1575        }
1576    }
1577
1578    #[tokio::test]
1579    async fn test_recordbatches_conversion() {
1580        let column_schemas = vec![
1581            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1582            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1583        ];
1584        let schema = Arc::new(Schema::new(column_schemas));
1585        let columns: Vec<VectorRef> = vec![
1586            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1587            Arc::new(StringVector::from(vec![
1588                None,
1589                Some("hello"),
1590                Some("greptime"),
1591                None,
1592            ])),
1593        ];
1594        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1595
1596        for format in [
1597            ResponseFormat::GreptimedbV1,
1598            ResponseFormat::InfluxdbV1,
1599            ResponseFormat::Csv(true, true),
1600            ResponseFormat::Table,
1601            ResponseFormat::Arrow,
1602            ResponseFormat::Json,
1603            ResponseFormat::Null,
1604        ] {
1605            let recordbatches =
1606                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1607            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1608            let json_resp = match format {
1609                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1610                ResponseFormat::Csv(with_names, with_types) => {
1611                    CsvResponse::from_output(outputs, with_names, with_types).await
1612                }
1613                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1614                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1615                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1616                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1617                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1618            };
1619
1620            match json_resp {
1621                HttpResponse::GreptimedbV1(resp) => {
1622                    let json_output = &resp.output[0];
1623                    if let GreptimeQueryOutput::Records(r) = json_output {
1624                        assert_eq!(r.num_rows(), 4);
1625                        assert_eq!(r.num_cols(), 2);
1626                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1627                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1628                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1629                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1630                    } else {
1631                        panic!("invalid output type");
1632                    }
1633                }
1634                HttpResponse::InfluxdbV1(resp) => {
1635                    let json_output = &resp.results()[0];
1636                    assert_eq!(json_output.num_rows(), 4);
1637                    assert_eq!(json_output.num_cols(), 2);
1638                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1639                    assert_eq!(
1640                        json_output.series[0].values[0][0],
1641                        serde_json::Value::from(1)
1642                    );
1643                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1644                }
1645                HttpResponse::Csv(resp) => {
1646                    let output = &resp.output()[0];
1647                    if let GreptimeQueryOutput::Records(r) = output {
1648                        assert_eq!(r.num_rows(), 4);
1649                        assert_eq!(r.num_cols(), 2);
1650                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1651                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1652                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1653                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1654                    } else {
1655                        panic!("invalid output type");
1656                    }
1657                }
1658
1659                HttpResponse::Table(resp) => {
1660                    let output = &resp.output()[0];
1661                    if let GreptimeQueryOutput::Records(r) = output {
1662                        assert_eq!(r.num_rows(), 4);
1663                        assert_eq!(r.num_cols(), 2);
1664                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1665                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1666                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1667                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1668                    } else {
1669                        panic!("invalid output type");
1670                    }
1671                }
1672
1673                HttpResponse::Arrow(resp) => {
1674                    let output = resp.data;
1675                    let mut reader =
1676                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
1677                    let schema = reader.schema();
1678                    assert_eq!(schema.fields[0].name(), "numbers");
1679                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1680                    assert_eq!(schema.fields[1].name(), "strings");
1681                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1682
1683                    let rb = reader.next().unwrap().expect("read record batch failed");
1684                    assert_eq!(rb.num_columns(), 2);
1685                    assert_eq!(rb.num_rows(), 4);
1686                }
1687
1688                HttpResponse::Json(resp) => {
1689                    let output = &resp.output()[0];
1690                    if let GreptimeQueryOutput::Records(r) = output {
1691                        assert_eq!(r.num_rows(), 4);
1692                        assert_eq!(r.num_cols(), 2);
1693                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1694                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1695                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1696                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1697                    } else {
1698                        panic!("invalid output type");
1699                    }
1700                }
1701
1702                HttpResponse::Null(resp) => {
1703                    assert_eq!(resp.rows(), 4);
1704                }
1705
1706                HttpResponse::Error(err) => unreachable!("{err:?}"),
1707            }
1708        }
1709    }
1710
1711    #[test]
1712    fn test_response_format_misc() {
1713        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
1714        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
1715        assert_eq!(
1716            ResponseFormat::parse("csv"),
1717            Some(ResponseFormat::Csv(false, false))
1718        );
1719        assert_eq!(
1720            ResponseFormat::parse("csvwithnames"),
1721            Some(ResponseFormat::Csv(true, false))
1722        );
1723        assert_eq!(
1724            ResponseFormat::parse("csvwithnamesandtypes"),
1725            Some(ResponseFormat::Csv(true, true))
1726        );
1727        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
1728        assert_eq!(
1729            ResponseFormat::parse("greptimedb_v1"),
1730            Some(ResponseFormat::GreptimedbV1)
1731        );
1732        assert_eq!(
1733            ResponseFormat::parse("influxdb_v1"),
1734            Some(ResponseFormat::InfluxdbV1)
1735        );
1736        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
1737        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));
1738
1739        // invalid formats
1740        assert_eq!(ResponseFormat::parse("invalid"), None);
1741        assert_eq!(ResponseFormat::parse(""), None);
1742        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive
1743
1744        // as str
1745        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
1746        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
1747        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
1748        assert_eq!(ResponseFormat::Table.as_str(), "table");
1749        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
1750        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
1751        assert_eq!(ResponseFormat::Json.as_str(), "json");
1752        assert_eq!(ResponseFormat::Null.as_str(), "null");
1753        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
1754    }
1755}