servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::Display;
17use std::net::SocketAddr;
18use std::sync::Mutex as StdMutex;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use auth::UserProviderRef;
23use axum::extract::DefaultBodyLimit;
24use axum::http::StatusCode as HttpStatusCode;
25use axum::response::{IntoResponse, Response};
26use axum::serve::ListenerExt;
27use axum::{Router, middleware, routing};
28use common_base::Plugins;
29use common_base::readable_size::ReadableSize;
30use common_recordbatch::RecordBatch;
31use common_telemetry::{debug, error, info};
32use common_time::Timestamp;
33use common_time::timestamp::TimeUnit;
34use datatypes::data_type::DataType;
35use datatypes::prelude::ConcreteDataType;
36use datatypes::schema::SchemaRef;
37use datatypes::types::json_type_value_to_serde_json;
38use event::{LogState, LogValidatorRef};
39use futures::FutureExt;
40use http::{HeaderValue, Method};
41use prost::DecodeError;
42use serde::{Deserialize, Serialize};
43use serde_json::Value;
44use snafu::{ResultExt, ensure};
45use tokio::sync::Mutex;
46use tokio::sync::oneshot::{self, Sender};
47use tower::ServiceBuilder;
48use tower_http::compression::CompressionLayer;
49use tower_http::cors::{AllowOrigin, Any, CorsLayer};
50use tower_http::decompression::RequestDecompressionLayer;
51use tower_http::trace::TraceLayer;
52
53use self::authorize::AuthState;
54use self::result::table_result::TableResponse;
55use crate::configurator::ConfiguratorRef;
56use crate::elasticsearch;
57use crate::error::{
58    AddressBindSnafu, AlreadyStartedSnafu, ConvertSqlValueSnafu, Error, InternalIoSnafu,
59    InvalidHeaderValueSnafu, Result, ToJsonSnafu,
60};
61use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
62use crate::http::otlp::OtlpState;
63use crate::http::prom_store::PromStoreState;
64use crate::http::prometheus::{
65    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
66    range_query, series_query,
67};
68use crate::http::result::arrow_result::ArrowResponse;
69use crate::http::result::csv_result::CsvResponse;
70use crate::http::result::error_result::ErrorResponse;
71use crate::http::result::greptime_result_v1::GreptimedbV1Response;
72use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
73use crate::http::result::json_result::JsonResponse;
74use crate::http::result::null_result::NullResponse;
75use crate::interceptor::LogIngestInterceptorRef;
76use crate::metrics::http_metrics_layer;
77use crate::metrics_handler::MetricsHandler;
78use crate::prometheus_handler::PrometheusHandlerRef;
79use crate::query_handler::sql::ServerSqlQueryHandlerRef;
80use crate::query_handler::{
81    InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
82    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
83    PromStoreProtocolHandlerRef,
84};
85use crate::server::Server;
86
87pub mod authorize;
88#[cfg(feature = "dashboard")]
89mod dashboard;
90pub mod dyn_log;
91pub mod event;
92pub mod extractor;
93pub mod handler;
94pub mod header;
95pub mod influxdb;
96pub mod jaeger;
97pub mod logs;
98pub mod loki;
99pub mod mem_prof;
100pub mod opentsdb;
101pub mod otlp;
102pub mod pprof;
103pub mod prom_store;
104pub mod prometheus;
105pub mod result;
106mod timeout;
107pub mod utils;
108
109pub(crate) use timeout::DynamicTimeoutLayer;
110
111mod hints;
112mod read_preference;
113#[cfg(any(test, feature = "testing"))]
114pub mod test_helpers;
115
/// Version segment used when nesting the versioned HTTP API routes (e.g. `/v1/...`).
pub const HTTP_API_VERSION: &str = "v1";
/// Path prefix of the versioned HTTP API; spells out the same `v1` as [HTTP_API_VERSION].
pub const HTTP_API_PREFIX: &str = "/v1/";
/// Default http body limit (64M), used by `HttpOptions::default`.
const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);

/// Name of the custom authorization header.
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";

/// Paths treated as public; presumably exempt from HTTP authentication — confirm in `authorize`.
// TODO(fys): This is a temporary workaround, it will be improved later
pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
126
/// HTTP server serving the routes assembled by `HttpServerBuilder`.
#[derive(Default)]
pub struct HttpServer {
    /// Routes registered through the builder; cloned out by `HttpServer::make_app`.
    router: StdMutex<Router>,
    /// Sender of a shutdown signal; `None` right after building. Presumably filled in when the
    /// server starts — confirm in the `Server` implementation.
    shutdown_tx: Mutex<Option<Sender<()>>>,
    /// Optional user provider handed to the auth middleware in `HttpServer::build`.
    user_provider: Option<UserProviderRef>,

    // plugins
    plugins: Plugins,

    // server configs
    options: HttpOptions,
    /// Bound address; `None` when produced by `HttpServerBuilder::build`.
    bind_addr: Option<SocketAddr>,
}
140
/// Configuration of the HTTP server.
///
/// Thanks to `#[serde(default)]`, any field missing from the configuration falls back to the
/// value from `HttpOptions::default()`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct HttpOptions {
    /// Address to listen on.
    pub addr: String,

    /// Request timeout in humantime notation (e.g. `30s`); a zero duration disables the
    /// timeout middleware.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,

    /// Turns off the dashboard routes (only relevant with the `dashboard` feature);
    /// skipped by serde.
    #[serde(skip)]
    pub disable_dashboard: bool,

    /// Maximum request body size; `ReadableSize(0)` is treated as "body limit disabled".
    pub body_limit: ReadableSize,

    /// Validation mode while decoding Prometheus remote write requests.
    pub prom_validation_mode: PromValidationMode,

    /// Origins allowed by CORS; an empty list allows any origin.
    pub cors_allowed_origins: Vec<String>,

    /// Whether the CORS middleware is installed at all.
    pub enable_cors: bool,
}
161
/// How strictly strings decoded from Prometheus remote write requests are checked for valid
/// UTF-8. Serialized in snake_case: `strict`, `lossy`, `unchecked`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PromValidationMode {
    /// Force UTF8 validation; invalid input is rejected with an error.
    Strict,
    /// Allow lossy UTF8 strings; invalid sequences are replaced.
    Lossy,
    /// Do not validate UTF8 strings.
    Unchecked,
}
172
173impl PromValidationMode {
174    /// Decodes provided bytes to [String] with optional UTF-8 validation.
175    pub fn decode_string(&self, bytes: &[u8]) -> std::result::Result<String, DecodeError> {
176        let result = match self {
177            PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
178                Ok(s) => s,
179                Err(e) => {
180                    debug!("Invalid UTF-8 string value: {:?}, error: {:?}", bytes, e);
181                    return Err(DecodeError::new("invalid utf-8"));
182                }
183            },
184            PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
185            PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
186        };
187        Ok(result)
188    }
189}
190
191impl Default for HttpOptions {
192    fn default() -> Self {
193        Self {
194            addr: "127.0.0.1:4000".to_string(),
195            timeout: Duration::from_secs(0),
196            disable_dashboard: false,
197            body_limit: DEFAULT_BODY_LIMIT,
198            cors_allowed_origins: Vec::new(),
199            enable_cors: true,
200            prom_validation_mode: PromValidationMode::Strict,
201        }
202    }
203}
204
/// Name and data-type name of one output column, as serialized in HTTP query results.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct ColumnSchema {
    /// Column name.
    name: String,
    /// Data type name (filled from the column type's `name()` when built from a schema).
    data_type: String,
}
210
211impl ColumnSchema {
212    pub fn new(name: String, data_type: String) -> ColumnSchema {
213        ColumnSchema { name, data_type }
214    }
215}
216
/// Ordered column descriptions of a query result.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct OutputSchema {
    /// One entry per result column, in column order.
    column_schemas: Vec<ColumnSchema>,
}
221
222impl OutputSchema {
223    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
224        OutputSchema {
225            column_schemas: columns,
226        }
227    }
228}
229
230impl From<SchemaRef> for OutputSchema {
231    fn from(schema: SchemaRef) -> OutputSchema {
232        OutputSchema {
233            column_schemas: schema
234                .column_schemas()
235                .iter()
236                .map(|cs| ColumnSchema {
237                    name: cs.name.clone(),
238                    data_type: cs.data_type.name(),
239                })
240                .collect(),
241        }
242    }
243}
244
/// Tabular query result: a schema plus rows of JSON values.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct HttpRecordsOutput {
    /// Names and type names of the columns in `rows`.
    schema: OutputSchema,
    /// Cell values, one inner `Vec` per row holding one value per schema column.
    rows: Vec<Vec<Value>>,
    // total_rows is equal to rows.len() in most cases,
    // the Dashboard query result may be truncated, so we need to return the total_rows.
    // Defaults to 0 when absent during deserialization.
    #[serde(default)]
    total_rows: usize,

    // plan level execution metrics
    // Omitted from the serialized output when empty; defaults to empty when deserializing.
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    #[serde(default)]
    metrics: HashMap<String, Value>,
}
259
260impl HttpRecordsOutput {
261    pub fn num_rows(&self) -> usize {
262        self.rows.len()
263    }
264
265    pub fn num_cols(&self) -> usize {
266        self.schema.column_schemas.len()
267    }
268
269    pub fn schema(&self) -> &OutputSchema {
270        &self.schema
271    }
272
273    pub fn rows(&self) -> &Vec<Vec<Value>> {
274        &self.rows
275    }
276}
277
278impl HttpRecordsOutput {
279    pub fn try_new(
280        schema: SchemaRef,
281        recordbatches: Vec<RecordBatch>,
282    ) -> std::result::Result<HttpRecordsOutput, Error> {
283        if recordbatches.is_empty() {
284            Ok(HttpRecordsOutput {
285                schema: OutputSchema::from(schema),
286                rows: vec![],
287                total_rows: 0,
288                metrics: Default::default(),
289            })
290        } else {
291            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
292            let mut rows = Vec::with_capacity(num_rows);
293            let schemas = schema.column_schemas();
294            let num_cols = schema.column_schemas().len();
295            rows.resize_with(num_rows, || Vec::with_capacity(num_cols));
296
297            let mut finished_row_cursor = 0;
298            for recordbatch in recordbatches {
299                for (col_idx, col) in recordbatch.columns().iter().enumerate() {
300                    // safety here: schemas length is equal to the number of columns in the recordbatch
301                    let schema = &schemas[col_idx];
302                    for row_idx in 0..recordbatch.num_rows() {
303                        let value = col.get(row_idx);
304                        let value = if let ConcreteDataType::Json(json_type) = schema.data_type
305                            && let datatypes::value::Value::Binary(bytes) = value
306                        {
307                            json_type_value_to_serde_json(bytes.as_ref(), &json_type.format)
308                                .context(ConvertSqlValueSnafu)?
309                        } else {
310                            serde_json::Value::try_from(col.get(row_idx)).context(ToJsonSnafu)?
311                        };
312
313                        rows[row_idx + finished_row_cursor].push(value);
314                    }
315                }
316                finished_row_cursor += recordbatch.num_rows();
317            }
318
319            Ok(HttpRecordsOutput {
320                schema: OutputSchema::from(schema),
321                total_rows: rows.len(),
322                rows,
323                metrics: Default::default(),
324            })
325        }
326    }
327}
328
/// One statement's result inside an HTTP query response.
/// Variant names are serialized in lowercase (`affectedrows`, `records`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GreptimeQueryOutput {
    /// Number of affected rows.
    AffectedRows(usize),
    /// Rows returned by the statement.
    Records(HttpRecordsOutput),
}
335
/// It allows the results of SQL queries to be presented in different formats.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    /// Arrow (`arrow`).
    Arrow,
    /// CSV as `(with_names, with_types)`: `csv`, `csvwithnames` or `csvwithnamesandtypes`.
    Csv(bool, bool),
    /// Table (`table`).
    Table,
    /// GreptimeDB v1 (`greptimedb_v1`); the default format.
    #[default]
    GreptimedbV1,
    /// InfluxDB v1 (`influxdb_v1`).
    InfluxdbV1,
    /// JSON (`json`).
    Json,
    /// Null (`null`).
    Null,
}

impl ResponseFormat {
    /// Every recognized format name paired with the format it selects.
    const NAMES: [(&'static str, ResponseFormat); 9] = [
        ("arrow", ResponseFormat::Arrow),
        ("csv", ResponseFormat::Csv(false, false)),
        ("csvwithnames", ResponseFormat::Csv(true, false)),
        ("csvwithnamesandtypes", ResponseFormat::Csv(true, true)),
        ("table", ResponseFormat::Table),
        ("greptimedb_v1", ResponseFormat::GreptimedbV1),
        ("influxdb_v1", ResponseFormat::InfluxdbV1),
        ("json", ResponseFormat::Json),
        ("null", ResponseFormat::Null),
    ];

    /// Looks up a format by its exact, case-sensitive name; unknown names yield `None`.
    pub fn parse(s: &str) -> Option<Self> {
        Self::NAMES
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, selected)| selected)
    }

    /// Returns the canonical name of the format. Every CSV flavour reports `"csv"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
378
/// Timestamp precision requested through InfluxDB's `epoch` query-string parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Epoch {
    Nanosecond,
    Microsecond,
    Millisecond,
    Second,
}
386
387impl Epoch {
388    pub fn parse(s: &str) -> Option<Epoch> {
389        // Both u and µ indicate microseconds.
390        // epoch = [ns,u,µ,ms,s],
391        // For details, see the Influxdb documents.
392        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
393        match s {
394            "ns" => Some(Epoch::Nanosecond),
395            "u" | "µ" => Some(Epoch::Microsecond),
396            "ms" => Some(Epoch::Millisecond),
397            "s" => Some(Epoch::Second),
398            _ => None, // just returns None for other cases
399        }
400    }
401
402    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
403        match self {
404            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
405            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
406            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
407            Epoch::Second => ts.convert_to(TimeUnit::Second),
408        }
409    }
410}
411
412impl Display for Epoch {
413    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414        match self {
415            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
416            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
417            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
418            Epoch::Second => write!(f, "Epoch::Second"),
419        }
420    }
421}
422
/// An HTTP query response in one of the supported output formats, or an error response.
#[derive(Serialize, Deserialize, Debug)]
pub enum HttpResponse {
    Arrow(ArrowResponse),
    Csv(CsvResponse),
    Table(TableResponse),
    /// Error response.
    Error(ErrorResponse),
    GreptimedbV1(GreptimedbV1Response),
    InfluxdbV1(InfluxdbV1Response),
    Json(JsonResponse),
    Null(NullResponse),
}
434
435impl HttpResponse {
436    pub fn with_execution_time(self, execution_time: u64) -> Self {
437        match self {
438            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
439            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
440            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
441            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
442            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
443            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
444            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
445            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
446        }
447    }
448
449    pub fn with_limit(self, limit: usize) -> Self {
450        match self {
451            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
452            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
453            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
454            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
455            _ => self,
456        }
457    }
458}
459
460pub fn process_with_limit(
461    mut outputs: Vec<GreptimeQueryOutput>,
462    limit: usize,
463) -> Vec<GreptimeQueryOutput> {
464    outputs
465        .drain(..)
466        .map(|data| match data {
467            GreptimeQueryOutput::Records(mut records) => {
468                if records.rows.len() > limit {
469                    records.rows.truncate(limit);
470                    records.total_rows = limit;
471                }
472                GreptimeQueryOutput::Records(records)
473            }
474            _ => data,
475        })
476        .collect()
477}
478
479impl IntoResponse for HttpResponse {
480    fn into_response(self) -> Response {
481        match self {
482            HttpResponse::Arrow(resp) => resp.into_response(),
483            HttpResponse::Csv(resp) => resp.into_response(),
484            HttpResponse::Table(resp) => resp.into_response(),
485            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
486            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
487            HttpResponse::Json(resp) => resp.into_response(),
488            HttpResponse::Null(resp) => resp.into_response(),
489            HttpResponse::Error(resp) => resp.into_response(),
490        }
491    }
492}
493
494impl From<ArrowResponse> for HttpResponse {
495    fn from(value: ArrowResponse) -> Self {
496        HttpResponse::Arrow(value)
497    }
498}
499
500impl From<CsvResponse> for HttpResponse {
501    fn from(value: CsvResponse) -> Self {
502        HttpResponse::Csv(value)
503    }
504}
505
506impl From<TableResponse> for HttpResponse {
507    fn from(value: TableResponse) -> Self {
508        HttpResponse::Table(value)
509    }
510}
511
512impl From<ErrorResponse> for HttpResponse {
513    fn from(value: ErrorResponse) -> Self {
514        HttpResponse::Error(value)
515    }
516}
517
518impl From<GreptimedbV1Response> for HttpResponse {
519    fn from(value: GreptimedbV1Response) -> Self {
520        HttpResponse::GreptimedbV1(value)
521    }
522}
523
524impl From<InfluxdbV1Response> for HttpResponse {
525    fn from(value: InfluxdbV1Response) -> Self {
526        HttpResponse::InfluxdbV1(value)
527    }
528}
529
530impl From<JsonResponse> for HttpResponse {
531    fn from(value: JsonResponse) -> Self {
532        HttpResponse::Json(value)
533    }
534}
535
536impl From<NullResponse> for HttpResponse {
537    fn from(value: NullResponse) -> Self {
538        HttpResponse::Null(value)
539    }
540}
541
/// State shared by the SQL routes registered in `HttpServerBuilder::with_sql_handler`.
#[derive(Clone)]
pub struct ApiState {
    /// Server-side SQL query handler.
    pub sql_handler: ServerSqlQueryHandlerRef,
}
546
/// State of the config route registered in `HttpServerBuilder::with_greptime_config_options`.
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    /// Pre-rendered configuration text; its exact format is decided by the caller.
    pub greptime_config_options: String,
}
551
/// Incrementally assembles the routes and settings of an [HttpServer].
#[derive(Default)]
pub struct HttpServerBuilder {
    /// Options handed to the built server.
    options: HttpOptions,
    /// Plugins handed to the built server.
    plugins: Plugins,
    /// Optional user provider handed to the built server's auth middleware.
    user_provider: Option<UserProviderRef>,
    /// Routes registered so far by the `with_*` methods.
    router: Router,
}
559
560impl HttpServerBuilder {
561    pub fn new(options: HttpOptions) -> Self {
562        Self {
563            options,
564            plugins: Plugins::default(),
565            user_provider: None,
566            router: Router::new(),
567        }
568    }
569
570    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
571        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
572
573        Self {
574            router: self
575                .router
576                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
577            ..self
578        }
579    }
580
581    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
582        let logs_router = HttpServer::route_logs(logs_handler);
583
584        Self {
585            router: self
586                .router
587                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
588            ..self
589        }
590    }
591
592    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
593        Self {
594            router: self.router.nest(
595                &format!("/{HTTP_API_VERSION}/opentsdb"),
596                HttpServer::route_opentsdb(handler),
597            ),
598            ..self
599        }
600    }
601
602    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
603        Self {
604            router: self.router.nest(
605                &format!("/{HTTP_API_VERSION}/influxdb"),
606                HttpServer::route_influxdb(handler),
607            ),
608            ..self
609        }
610    }
611
612    pub fn with_prom_handler(
613        self,
614        handler: PromStoreProtocolHandlerRef,
615        pipeline_handler: Option<PipelineHandlerRef>,
616        prom_store_with_metric_engine: bool,
617        prom_validation_mode: PromValidationMode,
618    ) -> Self {
619        let state = PromStoreState {
620            prom_store_handler: handler,
621            pipeline_handler,
622            prom_store_with_metric_engine,
623            prom_validation_mode,
624        };
625
626        Self {
627            router: self.router.nest(
628                &format!("/{HTTP_API_VERSION}/prometheus"),
629                HttpServer::route_prom(state),
630            ),
631            ..self
632        }
633    }
634
635    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
636        Self {
637            router: self.router.nest(
638                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
639                HttpServer::route_prometheus(handler),
640            ),
641            ..self
642        }
643    }
644
645    pub fn with_otlp_handler(
646        self,
647        handler: OpenTelemetryProtocolHandlerRef,
648        with_metric_engine: bool,
649    ) -> Self {
650        Self {
651            router: self.router.nest(
652                &format!("/{HTTP_API_VERSION}/otlp"),
653                HttpServer::route_otlp(handler, with_metric_engine),
654            ),
655            ..self
656        }
657    }
658
659    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
660        Self {
661            user_provider: Some(user_provider),
662            ..self
663        }
664    }
665
666    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
667        Self {
668            router: self.router.merge(HttpServer::route_metrics(handler)),
669            ..self
670        }
671    }
672
673    pub fn with_log_ingest_handler(
674        self,
675        handler: PipelineHandlerRef,
676        validator: Option<LogValidatorRef>,
677        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
678    ) -> Self {
679        let log_state = LogState {
680            log_handler: handler,
681            log_validator: validator,
682            ingest_interceptor,
683        };
684
685        let router = self.router.nest(
686            &format!("/{HTTP_API_VERSION}"),
687            HttpServer::route_pipelines(log_state.clone()),
688        );
689        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
690        let router = router.nest(
691            &format!("/{HTTP_API_VERSION}/events"),
692            #[allow(deprecated)]
693            HttpServer::route_log_deprecated(log_state.clone()),
694        );
695
696        let router = router.nest(
697            &format!("/{HTTP_API_VERSION}/loki"),
698            HttpServer::route_loki(log_state.clone()),
699        );
700
701        let router = router.nest(
702            &format!("/{HTTP_API_VERSION}/elasticsearch"),
703            HttpServer::route_elasticsearch(log_state.clone()),
704        );
705
706        let router = router.nest(
707            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
708            Router::new()
709                .route("/", routing::get(elasticsearch::handle_get_version))
710                .with_state(log_state),
711        );
712
713        Self { router, ..self }
714    }
715
716    pub fn with_plugins(self, plugins: Plugins) -> Self {
717        Self { plugins, ..self }
718    }
719
720    pub fn with_greptime_config_options(self, opts: String) -> Self {
721        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
722            greptime_config_options: opts,
723        });
724
725        Self {
726            router: self.router.merge(config_router),
727            ..self
728        }
729    }
730
731    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
732        Self {
733            router: self.router.nest(
734                &format!("/{HTTP_API_VERSION}/jaeger"),
735                HttpServer::route_jaeger(handler),
736            ),
737            ..self
738        }
739    }
740
741    pub fn with_extra_router(self, router: Router) -> Self {
742        Self {
743            router: self.router.merge(router),
744            ..self
745        }
746    }
747
748    pub fn build(self) -> HttpServer {
749        HttpServer {
750            options: self.options,
751            user_provider: self.user_provider,
752            shutdown_tx: Mutex::new(None),
753            plugins: self.plugins,
754            router: StdMutex::new(self.router),
755            bind_addr: None,
756        }
757    }
758}
759
760impl HttpServer {
761    /// Gets the router and adds necessary root routes (health, status, dashboard).
762    pub fn make_app(&self) -> Router {
763        let mut router = {
764            let router = self.router.lock().unwrap();
765            router.clone()
766        };
767
768        router = router
769            .route(
770                "/health",
771                routing::get(handler::health).post(handler::health),
772            )
773            .route(
774                &format!("/{HTTP_API_VERSION}/health"),
775                routing::get(handler::health).post(handler::health),
776            )
777            .route(
778                "/ready",
779                routing::get(handler::health).post(handler::health),
780            );
781
782        router = router.route("/status", routing::get(handler::status));
783
784        #[cfg(feature = "dashboard")]
785        {
786            if !self.options.disable_dashboard {
787                info!("Enable dashboard service at '/dashboard'");
788                // redirect /dashboard to /dashboard/
789                router = router.route(
790                    "/dashboard",
791                    routing::get(|uri: axum::http::uri::Uri| async move {
792                        let path = uri.path();
793                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
794
795                        let new_uri = format!("{}/{}", path, query);
796                        axum::response::Redirect::permanent(&new_uri)
797                    }),
798                );
799
800                // "/dashboard" and "/dashboard/" are two different paths in Axum.
801                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
802                // So we explicitly route "/dashboard/" here.
803                router = router
804                    .route(
805                        "/dashboard/",
806                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
807                    )
808                    .route(
809                        "/dashboard/{*x}",
810                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
811                    );
812            }
813        }
814
815        // Add a layer to collect HTTP metrics for axum.
816        router = router.route_layer(middleware::from_fn(http_metrics_layer));
817
818        router
819    }
820
821    /// Attaches middlewares and debug routes to the router.
822    /// Callers should call this method after [HttpServer::make_app()].
823    pub fn build(&self, router: Router) -> Result<Router> {
824        let timeout_layer = if self.options.timeout != Duration::default() {
825            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
826        } else {
827            info!("HTTP server timeout is disabled");
828            None
829        };
830        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
831            Some(
832                ServiceBuilder::new()
833                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
834            )
835        } else {
836            info!("HTTP server body limit is disabled");
837            None
838        };
839        let cors_layer = if self.options.enable_cors {
840            Some(
841                CorsLayer::new()
842                    .allow_methods([
843                        Method::GET,
844                        Method::POST,
845                        Method::PUT,
846                        Method::DELETE,
847                        Method::HEAD,
848                    ])
849                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
850                        AllowOrigin::from(Any)
851                    } else {
852                        AllowOrigin::from(
853                            self.options
854                                .cors_allowed_origins
855                                .iter()
856                                .map(|s| {
857                                    HeaderValue::from_str(s.as_str())
858                                        .context(InvalidHeaderValueSnafu)
859                                })
860                                .collect::<Result<Vec<HeaderValue>>>()?,
861                        )
862                    })
863                    .allow_headers(Any),
864            )
865        } else {
866            info!("HTTP server cross-origin is disabled");
867            None
868        };
869
870        Ok(router
871            // middlewares
872            .layer(
873                ServiceBuilder::new()
874                    // disable on failure tracing. because printing out isn't very helpful,
875                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
876                    .layer(TraceLayer::new_for_http().on_failure(()))
877                    .option_layer(cors_layer)
878                    .option_layer(timeout_layer)
879                    .option_layer(body_limit_layer)
880                    // auth layer
881                    .layer(middleware::from_fn_with_state(
882                        AuthState::new(self.user_provider.clone()),
883                        authorize::check_http_auth,
884                    ))
885                    .layer(middleware::from_fn(hints::extract_hints))
886                    .layer(middleware::from_fn(
887                        read_preference::extract_read_preference,
888                    )),
889            )
890            // Handlers for debug, we don't expect a timeout.
891            .nest(
892                "/debug",
893                Router::new()
894                    // handler for changing log level dynamically
895                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
896                    .nest(
897                        "/prof",
898                        Router::new()
899                            .route("/cpu", routing::post(pprof::pprof_handler))
900                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
901                            .route(
902                                "/mem/activate",
903                                routing::post(mem_prof::activate_heap_prof_handler),
904                            )
905                            .route(
906                                "/mem/deactivate",
907                                routing::post(mem_prof::deactivate_heap_prof_handler),
908                            )
909                            .route(
910                                "/mem/status",
911                                routing::get(mem_prof::heap_prof_status_handler),
912                            ),
913                    ),
914            ))
915    }
916
917    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
918        Router::new()
919            .route("/metrics", routing::get(handler::metrics))
920            .with_state(metrics_handler)
921    }
922
923    fn route_loki<S>(log_state: LogState) -> Router<S> {
924        Router::new()
925            .route("/api/v1/push", routing::post(loki::loki_ingest))
926            .layer(
927                ServiceBuilder::new()
928                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
929            )
930            .with_state(log_state)
931    }
932
933    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
934        Router::new()
935            // Return fake responsefor HEAD '/' request.
936            .route(
937                "/",
938                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
939            )
940            // Return fake response for Elasticsearch version request.
941            .route("/", routing::get(elasticsearch::handle_get_version))
942            // Return fake response for Elasticsearch license request.
943            .route("/_license", routing::get(elasticsearch::handle_get_license))
944            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
945            .route(
946                "/{index}/_bulk",
947                routing::post(elasticsearch::handle_bulk_api_with_index),
948            )
949            // Return fake response for Elasticsearch ilm request.
950            .route(
951                "/_ilm/policy/{*path}",
952                routing::any((
953                    HttpStatusCode::OK,
954                    elasticsearch::elasticsearch_headers(),
955                    axum::Json(serde_json::json!({})),
956                )),
957            )
958            // Return fake response for Elasticsearch index template request.
959            .route(
960                "/_index_template/{*path}",
961                routing::any((
962                    HttpStatusCode::OK,
963                    elasticsearch::elasticsearch_headers(),
964                    axum::Json(serde_json::json!({})),
965                )),
966            )
967            // Return fake response for Elasticsearch ingest pipeline request.
968            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
969            .route(
970                "/_ingest/{*path}",
971                routing::any((
972                    HttpStatusCode::OK,
973                    elasticsearch::elasticsearch_headers(),
974                    axum::Json(serde_json::json!({})),
975                )),
976            )
977            // Return fake response for Elasticsearch nodes discovery request.
978            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
979            .route(
980                "/_nodes/{*path}",
981                routing::any((
982                    HttpStatusCode::OK,
983                    elasticsearch::elasticsearch_headers(),
984                    axum::Json(serde_json::json!({})),
985                )),
986            )
987            // Return fake response for Logstash APIs requests.
988            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
989            .route(
990                "/logstash/{*path}",
991                routing::any((
992                    HttpStatusCode::OK,
993                    elasticsearch::elasticsearch_headers(),
994                    axum::Json(serde_json::json!({})),
995                )),
996            )
997            .route(
998                "/_logstash/{*path}",
999                routing::any((
1000                    HttpStatusCode::OK,
1001                    elasticsearch::elasticsearch_headers(),
1002                    axum::Json(serde_json::json!({})),
1003                )),
1004            )
1005            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1006            .with_state(log_state)
1007    }
1008
1009    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1010    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1011        Router::new()
1012            .route("/logs", routing::post(event::log_ingester))
1013            .route(
1014                "/pipelines/{pipeline_name}",
1015                routing::get(event::query_pipeline),
1016            )
1017            .route(
1018                "/pipelines/{pipeline_name}",
1019                routing::post(event::add_pipeline),
1020            )
1021            .route(
1022                "/pipelines/{pipeline_name}",
1023                routing::delete(event::delete_pipeline),
1024            )
1025            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1026            .layer(
1027                ServiceBuilder::new()
1028                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1029            )
1030            .with_state(log_state)
1031    }
1032
1033    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1034        Router::new()
1035            .route("/ingest", routing::post(event::log_ingester))
1036            .route(
1037                "/pipelines/{pipeline_name}",
1038                routing::get(event::query_pipeline),
1039            )
1040            .route(
1041                "/pipelines/{pipeline_name}/ddl",
1042                routing::get(event::query_pipeline_ddl),
1043            )
1044            .route(
1045                "/pipelines/{pipeline_name}",
1046                routing::post(event::add_pipeline),
1047            )
1048            .route(
1049                "/pipelines/{pipeline_name}",
1050                routing::delete(event::delete_pipeline),
1051            )
1052            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1053            .layer(
1054                ServiceBuilder::new()
1055                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1056            )
1057            .with_state(log_state)
1058    }
1059
1060    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1061        Router::new()
1062            .route("/sql", routing::get(handler::sql).post(handler::sql))
1063            .route(
1064                "/sql/parse",
1065                routing::get(handler::sql_parse).post(handler::sql_parse),
1066            )
1067            .route(
1068                "/sql/format",
1069                routing::get(handler::sql_format).post(handler::sql_format),
1070            )
1071            .route(
1072                "/promql",
1073                routing::get(handler::promql).post(handler::promql),
1074            )
1075            .with_state(api_state)
1076    }
1077
1078    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1079        Router::new()
1080            .route("/logs", routing::get(logs::logs).post(logs::logs))
1081            .with_state(log_handler)
1082    }
1083
1084    /// Route Prometheus [HTTP API].
1085    ///
1086    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1087    fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1088        Router::new()
1089            .route(
1090                "/format_query",
1091                routing::post(format_query).get(format_query),
1092            )
1093            .route("/status/buildinfo", routing::get(build_info_query))
1094            .route("/query", routing::post(instant_query).get(instant_query))
1095            .route("/query_range", routing::post(range_query).get(range_query))
1096            .route("/labels", routing::post(labels_query).get(labels_query))
1097            .route("/series", routing::post(series_query).get(series_query))
1098            .route("/parse_query", routing::post(parse_query).get(parse_query))
1099            .route(
1100                "/label/{label_name}/values",
1101                routing::get(label_values_query),
1102            )
1103            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1104            .with_state(prometheus_handler)
1105    }
1106
1107    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1108    /// called `prom_store`.
1109    ///
1110    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1111    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1112    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1113        Router::new()
1114            .route("/read", routing::post(prom_store::remote_read))
1115            .route("/write", routing::post(prom_store::remote_write))
1116            .with_state(state)
1117    }
1118
1119    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1120        Router::new()
1121            .route("/write", routing::post(influxdb_write_v1))
1122            .route("/api/v2/write", routing::post(influxdb_write_v2))
1123            .layer(
1124                ServiceBuilder::new()
1125                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1126            )
1127            .route("/ping", routing::get(influxdb_ping))
1128            .route("/health", routing::get(influxdb_health))
1129            .with_state(influxdb_handler)
1130    }
1131
1132    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1133        Router::new()
1134            .route("/api/put", routing::post(opentsdb::put))
1135            .with_state(opentsdb_handler)
1136    }
1137
1138    fn route_otlp<S>(
1139        otlp_handler: OpenTelemetryProtocolHandlerRef,
1140        with_metric_engine: bool,
1141    ) -> Router<S> {
1142        Router::new()
1143            .route("/v1/metrics", routing::post(otlp::metrics))
1144            .route("/v1/traces", routing::post(otlp::traces))
1145            .route("/v1/logs", routing::post(otlp::logs))
1146            .layer(
1147                ServiceBuilder::new()
1148                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1149            )
1150            .with_state(OtlpState {
1151                with_metric_engine,
1152                handler: otlp_handler,
1153            })
1154    }
1155
1156    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1157        Router::new()
1158            .route("/config", routing::get(handler::config))
1159            .with_state(state)
1160    }
1161
1162    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1163        Router::new()
1164            .route("/api/services", routing::get(jaeger::handle_get_services))
1165            .route(
1166                "/api/services/{service_name}/operations",
1167                routing::get(jaeger::handle_get_operations_by_service),
1168            )
1169            .route(
1170                "/api/operations",
1171                routing::get(jaeger::handle_get_operations),
1172            )
1173            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1174            .route(
1175                "/api/traces/{trace_id}",
1176                routing::get(jaeger::handle_get_trace),
1177            )
1178            .with_state(handler)
1179    }
1180}
1181
/// Name reported by the HTTP server through `Server::name`.
pub const HTTP_SERVER: &str = "HTTP_SERVER";
1183
1184#[async_trait]
1185impl Server for HttpServer {
1186    async fn shutdown(&self) -> Result<()> {
1187        let mut shutdown_tx = self.shutdown_tx.lock().await;
1188        if let Some(tx) = shutdown_tx.take()
1189            && tx.send(()).is_err()
1190        {
1191            info!("Receiver dropped, the HTTP server has already exited");
1192        }
1193        info!("Shutdown HTTP server");
1194
1195        Ok(())
1196    }
1197
1198    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1199        let (tx, rx) = oneshot::channel();
1200        let serve = {
1201            let mut shutdown_tx = self.shutdown_tx.lock().await;
1202            ensure!(
1203                shutdown_tx.is_none(),
1204                AlreadyStartedSnafu { server: "HTTP" }
1205            );
1206
1207            let mut app = self.make_app();
1208            if let Some(configurator) = self.plugins.get::<ConfiguratorRef>() {
1209                app = configurator.config_http(app);
1210            }
1211            let app = self.build(app)?;
1212            let listener = tokio::net::TcpListener::bind(listening)
1213                .await
1214                .context(AddressBindSnafu { addr: listening })?
1215                .tap_io(|tcp_stream| {
1216                    if let Err(e) = tcp_stream.set_nodelay(true) {
1217                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1218                    }
1219                });
1220            let serve = axum::serve(listener, app.into_make_service());
1221
1222            // FIXME(yingwen): Support keepalive.
1223            // See:
1224            // - https://github.com/tokio-rs/axum/discussions/2939
1225            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1226            // let server = axum::Server::try_bind(&listening)
1227            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1228            //     .tcp_nodelay(true)
1229            //     // Enable TCP keepalive to close the dangling established connections.
1230            //     // It's configured to let the keepalive probes first send after the connection sits
1231            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1232            //     // So the connection will be closed after roughly 1 hour.
1233            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1234            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1235            //     .tcp_keepalive_retries(Some(6))
1236            //     .serve(app.into_make_service());
1237
1238            *shutdown_tx = Some(tx);
1239
1240            serve
1241        };
1242        let listening = serve.local_addr().context(InternalIoSnafu)?;
1243        info!("HTTP server is bound to {}", listening);
1244
1245        common_runtime::spawn_global(async move {
1246            if let Err(e) = serve
1247                .with_graceful_shutdown(rx.map(drop))
1248                .await
1249                .context(InternalIoSnafu)
1250            {
1251                error!(e; "Failed to shutdown http server");
1252            }
1253        });
1254
1255        self.bind_addr = Some(listening);
1256        Ok(())
1257    }
1258
1259    fn name(&self) -> &str {
1260        HTTP_SERVER
1261    }
1262
1263    fn bind_addr(&self) -> Option<SocketAddr> {
1264        self.bind_addr
1265    }
1266}
1267
1268#[cfg(test)]
1269mod test {
1270    use std::future::pending;
1271    use std::io::Cursor;
1272    use std::sync::Arc;
1273
1274    use arrow_ipc::reader::FileReader;
1275    use arrow_schema::DataType;
1276    use axum::handler::Handler;
1277    use axum::http::StatusCode;
1278    use axum::routing::get;
1279    use common_query::Output;
1280    use common_recordbatch::RecordBatches;
1281    use datafusion_expr::LogicalPlan;
1282    use datatypes::prelude::*;
1283    use datatypes::schema::{ColumnSchema, Schema};
1284    use datatypes::vectors::{StringVector, UInt32Vector};
1285    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1286    use query::parser::PromQuery;
1287    use query::query_engine::DescribeResult;
1288    use session::context::QueryContextRef;
1289    use sql::statements::statement::Statement;
1290    use tokio::sync::mpsc;
1291    use tokio::time::Instant;
1292
1293    use super::*;
1294    use crate::error::Error;
1295    use crate::http::test_helpers::TestClient;
1296    use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
1297
1298    struct DummyInstance {
1299        _tx: mpsc::Sender<(String, Vec<u8>)>,
1300    }
1301
1302    #[async_trait]
1303    impl SqlQueryHandler for DummyInstance {
1304        type Error = Error;
1305
1306        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1307            unimplemented!()
1308        }
1309
1310        async fn do_promql_query(
1311            &self,
1312            _: &PromQuery,
1313            _: QueryContextRef,
1314        ) -> Vec<std::result::Result<Output, Self::Error>> {
1315            unimplemented!()
1316        }
1317
1318        async fn do_exec_plan(
1319            &self,
1320            _stmt: Option<Statement>,
1321            _plan: LogicalPlan,
1322            _query_ctx: QueryContextRef,
1323        ) -> std::result::Result<Output, Self::Error> {
1324            unimplemented!()
1325        }
1326
1327        async fn do_describe(
1328            &self,
1329            _stmt: sql::statements::statement::Statement,
1330            _query_ctx: QueryContextRef,
1331        ) -> Result<Option<DescribeResult>> {
1332            unimplemented!()
1333        }
1334
1335        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1336            Ok(true)
1337        }
1338    }
1339
1340    fn timeout() -> DynamicTimeoutLayer {
1341        DynamicTimeoutLayer::new(Duration::from_millis(10))
1342    }
1343
1344    async fn forever() {
1345        pending().await
1346    }
1347
1348    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1349        make_test_app_custom(tx, HttpOptions::default())
1350    }
1351
1352    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1353        let instance = Arc::new(DummyInstance { _tx: tx });
1354        let sql_instance = ServerSqlQueryHandlerAdapter::arc(instance.clone());
1355        let server = HttpServerBuilder::new(options)
1356            .with_sql_handler(sql_instance)
1357            .build();
1358        server.build(server.make_app()).unwrap().route(
1359            "/test/timeout",
1360            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1361        )
1362    }
1363
1364    #[tokio::test]
1365    pub async fn test_cors() {
1366        // cors is on by default
1367        let (tx, _rx) = mpsc::channel(100);
1368        let app = make_test_app(tx);
1369        let client = TestClient::new(app).await;
1370
1371        let res = client.get("/health").send().await;
1372
1373        assert_eq!(res.status(), StatusCode::OK);
1374        assert_eq!(
1375            res.headers()
1376                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1377                .expect("expect cors header origin"),
1378            "*"
1379        );
1380
1381        let res = client.get("/v1/health").send().await;
1382
1383        assert_eq!(res.status(), StatusCode::OK);
1384        assert_eq!(
1385            res.headers()
1386                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1387                .expect("expect cors header origin"),
1388            "*"
1389        );
1390
1391        let res = client
1392            .options("/health")
1393            .header("Access-Control-Request-Headers", "x-greptime-auth")
1394            .header("Access-Control-Request-Method", "DELETE")
1395            .header("Origin", "https://example.com")
1396            .send()
1397            .await;
1398        assert_eq!(res.status(), StatusCode::OK);
1399        assert_eq!(
1400            res.headers()
1401                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1402                .expect("expect cors header origin"),
1403            "*"
1404        );
1405        assert_eq!(
1406            res.headers()
1407                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1408                .expect("expect cors header headers"),
1409            "*"
1410        );
1411        assert_eq!(
1412            res.headers()
1413                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1414                .expect("expect cors header methods"),
1415            "GET,POST,PUT,DELETE,HEAD"
1416        );
1417    }
1418
1419    #[tokio::test]
1420    pub async fn test_cors_custom_origins() {
1421        // cors is on by default
1422        let (tx, _rx) = mpsc::channel(100);
1423        let origin = "https://example.com";
1424
1425        let options = HttpOptions {
1426            cors_allowed_origins: vec![origin.to_string()],
1427            ..Default::default()
1428        };
1429
1430        let app = make_test_app_custom(tx, options);
1431        let client = TestClient::new(app).await;
1432
1433        let res = client.get("/health").header("Origin", origin).send().await;
1434
1435        assert_eq!(res.status(), StatusCode::OK);
1436        assert_eq!(
1437            res.headers()
1438                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1439                .expect("expect cors header origin"),
1440            origin
1441        );
1442
1443        let res = client
1444            .get("/health")
1445            .header("Origin", "https://notallowed.com")
1446            .send()
1447            .await;
1448
1449        assert_eq!(res.status(), StatusCode::OK);
1450        assert!(
1451            !res.headers()
1452                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1453        );
1454    }
1455
1456    #[tokio::test]
1457    pub async fn test_cors_disabled() {
1458        // cors is on by default
1459        let (tx, _rx) = mpsc::channel(100);
1460
1461        let options = HttpOptions {
1462            enable_cors: false,
1463            ..Default::default()
1464        };
1465
1466        let app = make_test_app_custom(tx, options);
1467        let client = TestClient::new(app).await;
1468
1469        let res = client.get("/health").send().await;
1470
1471        assert_eq!(res.status(), StatusCode::OK);
1472        assert!(
1473            !res.headers()
1474                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1475        );
1476    }
1477
1478    #[test]
1479    fn test_http_options_default() {
1480        let default = HttpOptions::default();
1481        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1482        assert_eq!(Duration::from_secs(0), default.timeout)
1483    }
1484
1485    #[tokio::test]
1486    async fn test_http_server_request_timeout() {
1487        common_telemetry::init_default_ut_logging();
1488
1489        let (tx, _rx) = mpsc::channel(100);
1490        let app = make_test_app(tx);
1491        let client = TestClient::new(app).await;
1492        let res = client.get("/test/timeout").send().await;
1493        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1494
1495        let now = Instant::now();
1496        let res = client
1497            .get("/test/timeout")
1498            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1499            .send()
1500            .await;
1501        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1502        let elapsed = now.elapsed();
1503        assert!(elapsed > Duration::from_millis(15));
1504
1505        tokio::time::timeout(
1506            Duration::from_millis(15),
1507            client
1508                .get("/test/timeout")
1509                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1510                .send(),
1511        )
1512        .await
1513        .unwrap_err();
1514
1515        tokio::time::timeout(
1516            Duration::from_millis(15),
1517            client
1518                .get("/test/timeout")
1519                .header(
1520                    GREPTIME_DB_HEADER_TIMEOUT,
1521                    humantime::format_duration(Duration::default()).to_string(),
1522                )
1523                .send(),
1524        )
1525        .await
1526        .unwrap_err();
1527    }
1528
1529    #[tokio::test]
1530    async fn test_schema_for_empty_response() {
1531        let column_schemas = vec![
1532            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1533            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1534        ];
1535        let schema = Arc::new(Schema::new(column_schemas));
1536
1537        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1538        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1539
1540        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1541        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1542            let json_output = &json_resp.output[0];
1543            if let GreptimeQueryOutput::Records(r) = json_output {
1544                assert_eq!(r.num_rows(), 0);
1545                assert_eq!(r.num_cols(), 2);
1546                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1547                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1548            } else {
1549                panic!("invalid output type");
1550            }
1551        } else {
1552            panic!("invalid format")
1553        }
1554    }
1555
1556    #[tokio::test]
1557    async fn test_recordbatches_conversion() {
1558        let column_schemas = vec![
1559            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1560            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1561        ];
1562        let schema = Arc::new(Schema::new(column_schemas));
1563        let columns: Vec<VectorRef> = vec![
1564            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1565            Arc::new(StringVector::from(vec![
1566                None,
1567                Some("hello"),
1568                Some("greptime"),
1569                None,
1570            ])),
1571        ];
1572        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1573
1574        for format in [
1575            ResponseFormat::GreptimedbV1,
1576            ResponseFormat::InfluxdbV1,
1577            ResponseFormat::Csv(true, true),
1578            ResponseFormat::Table,
1579            ResponseFormat::Arrow,
1580            ResponseFormat::Json,
1581            ResponseFormat::Null,
1582        ] {
1583            let recordbatches =
1584                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1585            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1586            let json_resp = match format {
1587                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1588                ResponseFormat::Csv(with_names, with_types) => {
1589                    CsvResponse::from_output(outputs, with_names, with_types).await
1590                }
1591                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1592                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1593                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1594                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1595                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1596            };
1597
1598            match json_resp {
1599                HttpResponse::GreptimedbV1(resp) => {
1600                    let json_output = &resp.output[0];
1601                    if let GreptimeQueryOutput::Records(r) = json_output {
1602                        assert_eq!(r.num_rows(), 4);
1603                        assert_eq!(r.num_cols(), 2);
1604                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1605                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1606                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1607                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1608                    } else {
1609                        panic!("invalid output type");
1610                    }
1611                }
1612                HttpResponse::InfluxdbV1(resp) => {
1613                    let json_output = &resp.results()[0];
1614                    assert_eq!(json_output.num_rows(), 4);
1615                    assert_eq!(json_output.num_cols(), 2);
1616                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1617                    assert_eq!(
1618                        json_output.series[0].values[0][0],
1619                        serde_json::Value::from(1)
1620                    );
1621                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1622                }
1623                HttpResponse::Csv(resp) => {
1624                    let output = &resp.output()[0];
1625                    if let GreptimeQueryOutput::Records(r) = output {
1626                        assert_eq!(r.num_rows(), 4);
1627                        assert_eq!(r.num_cols(), 2);
1628                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1629                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1630                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1631                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1632                    } else {
1633                        panic!("invalid output type");
1634                    }
1635                }
1636
1637                HttpResponse::Table(resp) => {
1638                    let output = &resp.output()[0];
1639                    if let GreptimeQueryOutput::Records(r) = output {
1640                        assert_eq!(r.num_rows(), 4);
1641                        assert_eq!(r.num_cols(), 2);
1642                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1643                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1644                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1645                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1646                    } else {
1647                        panic!("invalid output type");
1648                    }
1649                }
1650
1651                HttpResponse::Arrow(resp) => {
1652                    let output = resp.data;
1653                    let mut reader =
1654                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
1655                    let schema = reader.schema();
1656                    assert_eq!(schema.fields[0].name(), "numbers");
1657                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1658                    assert_eq!(schema.fields[1].name(), "strings");
1659                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1660
1661                    let rb = reader.next().unwrap().expect("read record batch failed");
1662                    assert_eq!(rb.num_columns(), 2);
1663                    assert_eq!(rb.num_rows(), 4);
1664                }
1665
1666                HttpResponse::Json(resp) => {
1667                    let output = &resp.output()[0];
1668                    if let GreptimeQueryOutput::Records(r) = output {
1669                        assert_eq!(r.num_rows(), 4);
1670                        assert_eq!(r.num_cols(), 2);
1671                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1672                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1673                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1674                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1675                    } else {
1676                        panic!("invalid output type");
1677                    }
1678                }
1679
1680                HttpResponse::Null(resp) => {
1681                    assert_eq!(resp.rows(), 4);
1682                }
1683
1684                HttpResponse::Error(err) => unreachable!("{err:?}"),
1685            }
1686        }
1687    }
1688
1689    #[test]
1690    fn test_response_format_misc() {
1691        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
1692        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
1693        assert_eq!(
1694            ResponseFormat::parse("csv"),
1695            Some(ResponseFormat::Csv(false, false))
1696        );
1697        assert_eq!(
1698            ResponseFormat::parse("csvwithnames"),
1699            Some(ResponseFormat::Csv(true, false))
1700        );
1701        assert_eq!(
1702            ResponseFormat::parse("csvwithnamesandtypes"),
1703            Some(ResponseFormat::Csv(true, true))
1704        );
1705        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
1706        assert_eq!(
1707            ResponseFormat::parse("greptimedb_v1"),
1708            Some(ResponseFormat::GreptimedbV1)
1709        );
1710        assert_eq!(
1711            ResponseFormat::parse("influxdb_v1"),
1712            Some(ResponseFormat::InfluxdbV1)
1713        );
1714        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
1715        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));
1716
1717        // invalid formats
1718        assert_eq!(ResponseFormat::parse("invalid"), None);
1719        assert_eq!(ResponseFormat::parse(""), None);
1720        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive
1721
1722        // as str
1723        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
1724        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
1725        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
1726        assert_eq!(ResponseFormat::Table.as_str(), "table");
1727        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
1728        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
1729        assert_eq!(ResponseFormat::Json.as_str(), "json");
1730        assert_eq!(ResponseFormat::Null.as_str(), "null");
1731        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
1732    }
1733}