servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::Display;
17use std::net::SocketAddr;
18use std::sync::Mutex as StdMutex;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use auth::UserProviderRef;
23use axum::extract::DefaultBodyLimit;
24use axum::http::StatusCode as HttpStatusCode;
25use axum::response::{IntoResponse, Response};
26use axum::serve::ListenerExt;
27use axum::{Router, middleware, routing};
28use common_base::Plugins;
29use common_base::readable_size::ReadableSize;
30use common_recordbatch::RecordBatch;
31use common_telemetry::{debug, error, info};
32use common_time::Timestamp;
33use common_time::timestamp::TimeUnit;
34use datatypes::data_type::DataType;
35use datatypes::prelude::ConcreteDataType;
36use datatypes::schema::SchemaRef;
37use datatypes::types::jsonb_to_serde_json;
38use event::{LogState, LogValidatorRef};
39use futures::FutureExt;
40use http::{HeaderValue, Method};
41use prost::DecodeError;
42use serde::{Deserialize, Serialize};
43use serde_json::Value;
44use snafu::{ResultExt, ensure};
45use tokio::sync::Mutex;
46use tokio::sync::oneshot::{self, Sender};
47use tower::ServiceBuilder;
48use tower_http::compression::CompressionLayer;
49use tower_http::cors::{AllowOrigin, Any, CorsLayer};
50use tower_http::decompression::RequestDecompressionLayer;
51use tower_http::trace::TraceLayer;
52
53use self::authorize::AuthState;
54use self::result::table_result::TableResponse;
55use crate::configurator::ConfiguratorRef;
56use crate::elasticsearch;
57use crate::error::{
58    AddressBindSnafu, AlreadyStartedSnafu, ConvertSqlValueSnafu, Error, InternalIoSnafu,
59    InvalidHeaderValueSnafu, Result, ToJsonSnafu,
60};
61use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
62use crate::http::otlp::OtlpState;
63use crate::http::prom_store::PromStoreState;
64use crate::http::prometheus::{
65    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
66    range_query, series_query,
67};
68use crate::http::result::arrow_result::ArrowResponse;
69use crate::http::result::csv_result::CsvResponse;
70use crate::http::result::error_result::ErrorResponse;
71use crate::http::result::greptime_result_v1::GreptimedbV1Response;
72use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
73use crate::http::result::json_result::JsonResponse;
74use crate::http::result::null_result::NullResponse;
75use crate::interceptor::LogIngestInterceptorRef;
76use crate::metrics::http_metrics_layer;
77use crate::metrics_handler::MetricsHandler;
78use crate::prometheus_handler::PrometheusHandlerRef;
79use crate::query_handler::sql::ServerSqlQueryHandlerRef;
80use crate::query_handler::{
81    InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
82    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
83    PromStoreProtocolHandlerRef,
84};
85use crate::request_limiter::RequestMemoryLimiter;
86use crate::server::Server;
87
88pub mod authorize;
89#[cfg(feature = "dashboard")]
90mod dashboard;
91pub mod dyn_log;
92pub mod event;
93pub mod extractor;
94pub mod handler;
95pub mod header;
96pub mod influxdb;
97pub mod jaeger;
98pub mod logs;
99pub mod loki;
100pub mod mem_prof;
101mod memory_limit;
102pub mod opentsdb;
103pub mod otlp;
104pub mod pprof;
105pub mod prom_store;
106pub mod prometheus;
107pub mod result;
108mod timeout;
109pub mod utils;
110
111pub(crate) use timeout::DynamicTimeoutLayer;
112
113mod hints;
114mod read_preference;
115#[cfg(any(test, feature = "testing"))]
116pub mod test_helpers;
117
/// API version segment used to prefix versioned routes (e.g. `/v1/health`).
pub const HTTP_API_VERSION: &str = "v1";
/// [HTTP_API_VERSION] as a path prefix, with leading and trailing slashes.
pub const HTTP_API_PREFIX: &str = "/v1/";
/// Default http body limit (64M), used by [HttpOptions::default]. A configured limit of
/// zero disables body limiting entirely.
const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);

/// Authorization header
/// (Greptime-specific header name, `x-greptime-auth`).
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";

/// Paths treated as public APIs.
/// NOTE(review): presumably consulted by the auth middleware to skip authentication — confirm
/// in `authorize`. Only the versioned health path is listed, not `/health` or `/ready`.
// TODO(fys): This is a temporary workaround, it will be improved later
pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
128
/// The GreptimeDB HTTP server. Normally created through [HttpServerBuilder::build].
#[derive(Default)]
pub struct HttpServer {
    /// Routes accumulated by the builder; `make_app` clones them while holding the lock.
    router: StdMutex<Router>,
    /// NOTE(review): presumably the sender used to signal shutdown once the server has
    /// started — confirm in the `Server` impl. `build` initializes it to `None`.
    shutdown_tx: Mutex<Option<Sender<()>>>,
    /// User provider handed to the authentication middleware (`AuthState::new`).
    user_provider: Option<UserProviderRef>,
    /// Caps the total memory of concurrent request bodies; sized from
    /// [HttpOptions::max_total_body_memory].
    memory_limiter: RequestMemoryLimiter,

    // plugins
    plugins: Plugins,

    // server configs
    options: HttpOptions,
    /// `None` after `build`; presumably set once the listener is bound — confirm.
    bind_addr: Option<SocketAddr>,
}
143
/// Configuration of the HTTP server.
///
/// Any field missing from a deserialized config falls back to [HttpOptions::default]
/// (`#[serde(default)]`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct HttpOptions {
    /// Address to listen on; `127.0.0.1:4000` by default.
    pub addr: String,

    /// Per-request timeout, written as a human-readable duration (`humantime_serde`).
    /// A zero duration (the default) disables the timeout layer.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,

    /// Skips registering the dashboard routes (only relevant when built with the
    /// `dashboard` feature). Never read from or written to config (`#[serde(skip)]`).
    #[serde(skip)]
    pub disable_dashboard: bool,

    /// Maximum size of a single request body. Zero disables the limit.
    pub body_limit: ReadableSize,

    /// Maximum total memory for all concurrent HTTP request bodies. 0 disables the limit.
    pub max_total_body_memory: ReadableSize,

    /// Validation mode while decoding Prometheus remote write requests.
    pub prom_validation_mode: PromValidationMode,

    /// Origins allowed by the CORS layer. An empty list allows any origin. Every entry must
    /// be a valid HTTP header value, otherwise `HttpServer::build` returns an error.
    pub cors_allowed_origins: Vec<String>,

    /// Whether to install the CORS layer at all.
    pub enable_cors: bool,
}
167
/// How strictly UTF-8 is checked when decoding strings from Prometheus remote write
/// requests (see [PromValidationMode::decode_string]).
///
/// Serialized in snake_case: `strict`, `lossy`, `unchecked`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PromValidationMode {
    /// Force UTF8 validation: invalid input is rejected with a decode error.
    Strict,
    /// Allow lossy UTF8 strings: invalid sequences are replaced with U+FFFD.
    Lossy,
    /// Do not validate UTF8 strings. The bytes are trusted as-is; invalid input produces a
    /// `String` that violates its UTF-8 invariant.
    Unchecked,
}
178
179impl PromValidationMode {
180    /// Decodes provided bytes to [String] with optional UTF-8 validation.
181    pub fn decode_string(&self, bytes: &[u8]) -> std::result::Result<String, DecodeError> {
182        let result = match self {
183            PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
184                Ok(s) => s,
185                Err(e) => {
186                    debug!("Invalid UTF-8 string value: {:?}, error: {:?}", bytes, e);
187                    return Err(DecodeError::new("invalid utf-8"));
188                }
189            },
190            PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
191            PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
192        };
193        Ok(result)
194    }
195}
196
197impl Default for HttpOptions {
198    fn default() -> Self {
199        Self {
200            addr: "127.0.0.1:4000".to_string(),
201            timeout: Duration::from_secs(0),
202            disable_dashboard: false,
203            body_limit: DEFAULT_BODY_LIMIT,
204            max_total_body_memory: ReadableSize(0),
205            cors_allowed_origins: Vec::new(),
206            enable_cors: true,
207            prom_validation_mode: PromValidationMode::Strict,
208        }
209    }
210}
211
/// Description of one result column in HTTP responses: its name and the display name of its
/// data type.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct ColumnSchema {
    name: String,
    /// Data type name as produced by `DataType::name()`.
    data_type: String,
}
217
218impl ColumnSchema {
219    pub fn new(name: String, data_type: String) -> ColumnSchema {
220        ColumnSchema { name, data_type }
221    }
222}
223
/// Schema of a record set in HTTP responses: its columns, in result order.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct OutputSchema {
    column_schemas: Vec<ColumnSchema>,
}
228
229impl OutputSchema {
230    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
231        OutputSchema {
232            column_schemas: columns,
233        }
234    }
235}
236
237impl From<SchemaRef> for OutputSchema {
238    fn from(schema: SchemaRef) -> OutputSchema {
239        OutputSchema {
240            column_schemas: schema
241                .column_schemas()
242                .iter()
243                .map(|cs| ColumnSchema {
244                    name: cs.name.clone(),
245                    data_type: cs.data_type.name(),
246                })
247                .collect(),
248        }
249    }
250}
251
/// A record set rendered as JSON: the column schema plus row-major values.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct HttpRecordsOutput {
    schema: OutputSchema,
    // One inner vector per row, holding one JSON value per column in schema order.
    rows: Vec<Vec<Value>>,
    // total_rows is equal to rows.len() in most cases,
    // the Dashboard query result may be truncated, so we need to return the total_rows.
    // Defaults to 0 when absent from deserialized input.
    #[serde(default)]
    total_rows: usize,

    // plan level execution metrics
    // Omitted from serialized output when empty.
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    #[serde(default)]
    metrics: HashMap<String, Value>,
}
266
267impl HttpRecordsOutput {
268    pub fn num_rows(&self) -> usize {
269        self.rows.len()
270    }
271
272    pub fn num_cols(&self) -> usize {
273        self.schema.column_schemas.len()
274    }
275
276    pub fn schema(&self) -> &OutputSchema {
277        &self.schema
278    }
279
280    pub fn rows(&self) -> &Vec<Vec<Value>> {
281        &self.rows
282    }
283}
284
285impl HttpRecordsOutput {
286    pub fn try_new(
287        schema: SchemaRef,
288        recordbatches: Vec<RecordBatch>,
289    ) -> std::result::Result<HttpRecordsOutput, Error> {
290        if recordbatches.is_empty() {
291            Ok(HttpRecordsOutput {
292                schema: OutputSchema::from(schema),
293                rows: vec![],
294                total_rows: 0,
295                metrics: Default::default(),
296            })
297        } else {
298            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
299            let mut rows = Vec::with_capacity(num_rows);
300            let schemas = schema.column_schemas();
301            let num_cols = schema.column_schemas().len();
302            rows.resize_with(num_rows, || Vec::with_capacity(num_cols));
303
304            let mut finished_row_cursor = 0;
305            for recordbatch in recordbatches {
306                for (col_idx, col) in recordbatch.columns().iter().enumerate() {
307                    // safety here: schemas length is equal to the number of columns in the recordbatch
308                    let schema = &schemas[col_idx];
309                    for row_idx in 0..recordbatch.num_rows() {
310                        let value = col.get(row_idx);
311                        // TODO(sunng87): is this duplicated with `map_json_type_to_string` in recordbatch?
312                        let value = if let ConcreteDataType::Json(_json_type) = &schema.data_type
313                            && let datatypes::value::Value::Binary(bytes) = value
314                        {
315                            jsonb_to_serde_json(bytes.as_ref()).context(ConvertSqlValueSnafu)?
316                        } else {
317                            serde_json::Value::try_from(col.get(row_idx)).context(ToJsonSnafu)?
318                        };
319
320                        rows[row_idx + finished_row_cursor].push(value);
321                    }
322                }
323                finished_row_cursor += recordbatch.num_rows();
324            }
325
326            Ok(HttpRecordsOutput {
327                schema: OutputSchema::from(schema),
328                total_rows: rows.len(),
329                rows,
330                metrics: Default::default(),
331            })
332        }
333    }
334}
335
/// Output of one executed statement: either an affected-row count or a record set.
///
/// Variant names are serialized in lowercase (`affectedrows`, `records`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GreptimeQueryOutput {
    AffectedRows(usize),
    Records(HttpRecordsOutput),
}
342
/// It allows the results of SQL queries to be presented in different formats.
///
/// Values are parsed from the lowercase names accepted by [ResponseFormat::parse].
/// [ResponseFormat::GreptimedbV1] is the default.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    /// CSV output; the flags are `(with_names, with_types)`.
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Parses a format name (case-sensitive). Returns `None` for unknown names.
    ///
    /// The CSV variants are `csv`, `csvwithnames` and `csvwithnamesandtypes`.
    pub fn parse(s: &str) -> Option<Self> {
        let format = match s {
            "arrow" => Self::Arrow,
            "csv" => Self::Csv(false, false),
            "csvwithnames" => Self::Csv(true, false),
            "csvwithnamesandtypes" => Self::Csv(true, true),
            "table" => Self::Table,
            "greptimedb_v1" => Self::GreptimedbV1,
            "influxdb_v1" => Self::InfluxdbV1,
            "json" => Self::Json,
            "null" => Self::Null,
            _ => return None,
        };
        Some(format)
    }

    /// Short name of the format. All CSV variants report `"csv"`, so this is not an exact
    /// inverse of [ResponseFormat::parse].
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
385
/// Time unit requested through InfluxDB's `epoch` query parameter; see [Epoch::parse].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Epoch {
    Nanosecond,
    Microsecond,
    Millisecond,
    Second,
}
393
394impl Epoch {
395    pub fn parse(s: &str) -> Option<Epoch> {
396        // Both u and µ indicate microseconds.
397        // epoch = [ns,u,µ,ms,s],
398        // For details, see the Influxdb documents.
399        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
400        match s {
401            "ns" => Some(Epoch::Nanosecond),
402            "u" | "µ" => Some(Epoch::Microsecond),
403            "ms" => Some(Epoch::Millisecond),
404            "s" => Some(Epoch::Second),
405            _ => None, // just returns None for other cases
406        }
407    }
408
409    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
410        match self {
411            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
412            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
413            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
414            Epoch::Second => ts.convert_to(TimeUnit::Second),
415        }
416    }
417}
418
419impl Display for Epoch {
420    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
421        match self {
422            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
423            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
424            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
425            Epoch::Second => write!(f, "Epoch::Second"),
426        }
427    }
428}
429
/// A query result rendered in one of the supported [ResponseFormat]s, or an error.
///
/// Each variant wraps a dedicated response type; rendering is delegated to it by the
/// `IntoResponse` impl for this enum.
#[derive(Serialize, Deserialize, Debug)]
pub enum HttpResponse {
    Arrow(ArrowResponse),
    Csv(CsvResponse),
    Table(TableResponse),
    Error(ErrorResponse),
    GreptimedbV1(GreptimedbV1Response),
    InfluxdbV1(InfluxdbV1Response),
    Json(JsonResponse),
    Null(NullResponse),
}
441
442impl HttpResponse {
443    pub fn with_execution_time(self, execution_time: u64) -> Self {
444        match self {
445            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
446            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
447            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
448            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
449            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
450            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
451            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
452            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
453        }
454    }
455
456    pub fn with_limit(self, limit: usize) -> Self {
457        match self {
458            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
459            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
460            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
461            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
462            _ => self,
463        }
464    }
465}
466
467pub fn process_with_limit(
468    mut outputs: Vec<GreptimeQueryOutput>,
469    limit: usize,
470) -> Vec<GreptimeQueryOutput> {
471    outputs
472        .drain(..)
473        .map(|data| match data {
474            GreptimeQueryOutput::Records(mut records) => {
475                if records.rows.len() > limit {
476                    records.rows.truncate(limit);
477                    records.total_rows = limit;
478                }
479                GreptimeQueryOutput::Records(records)
480            }
481            _ => data,
482        })
483        .collect()
484}
485
486impl IntoResponse for HttpResponse {
487    fn into_response(self) -> Response {
488        match self {
489            HttpResponse::Arrow(resp) => resp.into_response(),
490            HttpResponse::Csv(resp) => resp.into_response(),
491            HttpResponse::Table(resp) => resp.into_response(),
492            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
493            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
494            HttpResponse::Json(resp) => resp.into_response(),
495            HttpResponse::Null(resp) => resp.into_response(),
496            HttpResponse::Error(resp) => resp.into_response(),
497        }
498    }
499}
500
501impl From<ArrowResponse> for HttpResponse {
502    fn from(value: ArrowResponse) -> Self {
503        HttpResponse::Arrow(value)
504    }
505}
506
507impl From<CsvResponse> for HttpResponse {
508    fn from(value: CsvResponse) -> Self {
509        HttpResponse::Csv(value)
510    }
511}
512
513impl From<TableResponse> for HttpResponse {
514    fn from(value: TableResponse) -> Self {
515        HttpResponse::Table(value)
516    }
517}
518
519impl From<ErrorResponse> for HttpResponse {
520    fn from(value: ErrorResponse) -> Self {
521        HttpResponse::Error(value)
522    }
523}
524
525impl From<GreptimedbV1Response> for HttpResponse {
526    fn from(value: GreptimedbV1Response) -> Self {
527        HttpResponse::GreptimedbV1(value)
528    }
529}
530
531impl From<InfluxdbV1Response> for HttpResponse {
532    fn from(value: InfluxdbV1Response) -> Self {
533        HttpResponse::InfluxdbV1(value)
534    }
535}
536
537impl From<JsonResponse> for HttpResponse {
538    fn from(value: JsonResponse) -> Self {
539        HttpResponse::Json(value)
540    }
541}
542
543impl From<NullResponse> for HttpResponse {
544    fn from(value: NullResponse) -> Self {
545        HttpResponse::Null(value)
546    }
547}
548
/// Router state for the SQL query routes (`HttpServer::route_sql`).
#[derive(Clone)]
pub struct ApiState {
    /// Handler that executes the SQL queries received over HTTP.
    pub sql_handler: ServerSqlQueryHandlerRef,
}
553
/// Router state for the config route built by `HttpServer::route_config`.
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    /// Configuration pre-rendered as a string by the caller of
    /// `HttpServerBuilder::with_greptime_config_options` (format not visible here — confirm).
    pub greptime_config_options: String,
}
558
/// Builder for [HttpServer]. Each `with_*` method either mounts a group of protocol routes
/// or sets a dependency; [HttpServerBuilder::build] produces the server.
#[derive(Default)]
pub struct HttpServerBuilder {
    options: HttpOptions,
    plugins: Plugins,
    user_provider: Option<UserProviderRef>,
    // Routes accumulated so far; moved into the server by `build`.
    router: Router,
}
566
impl HttpServerBuilder {
    /// Creates a builder with `options`, default plugins, no user provider and an empty
    /// router.
    pub fn new(options: HttpOptions) -> Self {
        Self {
            options,
            plugins: Plugins::default(),
            user_provider: None,
            router: Router::new(),
        }
    }

    /// Mounts the SQL routes from `HttpServer::route_sql` under `/{HTTP_API_VERSION}`.
    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
        let sql_router = HttpServer::route_sql(ApiState { sql_handler });

        Self {
            router: self
                .router
                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
            ..self
        }
    }

    /// Mounts the log query routes from `HttpServer::route_logs` under
    /// `/{HTTP_API_VERSION}`.
    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
        let logs_router = HttpServer::route_logs(logs_handler);

        Self {
            router: self
                .router
                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
            ..self
        }
    }

    /// Mounts the OpenTSDB routes under `/{HTTP_API_VERSION}/opentsdb`.
    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/opentsdb"),
                HttpServer::route_opentsdb(handler),
            ),
            ..self
        }
    }

    /// Mounts the InfluxDB line-protocol routes under `/{HTTP_API_VERSION}/influxdb`.
    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/influxdb"),
                HttpServer::route_influxdb(handler),
            ),
            ..self
        }
    }

    /// Mounts the Prometheus remote-storage routes (`HttpServer::route_prom`) under
    /// `/{HTTP_API_VERSION}/prometheus`.
    ///
    /// `pipeline_handler`, `prom_store_with_metric_engine` and `prom_validation_mode` are
    /// passed through unchanged in the routes' [PromStoreState].
    pub fn with_prom_handler(
        self,
        handler: PromStoreProtocolHandlerRef,
        pipeline_handler: Option<PipelineHandlerRef>,
        prom_store_with_metric_engine: bool,
        prom_validation_mode: PromValidationMode,
    ) -> Self {
        let state = PromStoreState {
            prom_store_handler: handler,
            pipeline_handler,
            prom_store_with_metric_engine,
            prom_validation_mode,
        };

        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/prometheus"),
                HttpServer::route_prom(state),
            ),
            ..self
        }
    }

    /// Mounts the Prometheus HTTP query API routes under
    /// `/{HTTP_API_VERSION}/prometheus/api/v1`.
    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
                HttpServer::route_prometheus(handler),
            ),
            ..self
        }
    }

    /// Mounts the OTLP routes under `/{HTTP_API_VERSION}/otlp`. `with_metric_engine` is
    /// passed through to `HttpServer::route_otlp`.
    pub fn with_otlp_handler(
        self,
        handler: OpenTelemetryProtocolHandlerRef,
        with_metric_engine: bool,
    ) -> Self {
        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/otlp"),
                HttpServer::route_otlp(handler, with_metric_engine),
            ),
            ..self
        }
    }

    /// Sets the user provider used by the server's authentication middleware.
    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
        Self {
            user_provider: Some(user_provider),
            ..self
        }
    }

    /// Merges the `/metrics` route, served by `handler`, at the root.
    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
        Self {
            router: self.router.merge(HttpServer::route_metrics(handler)),
            ..self
        }
    }

    /// Mounts the log-ingestion routes. All of them share one [LogState] built from
    /// `handler`, `validator` and `ingest_interceptor`:
    /// - pipeline routes under `/{HTTP_API_VERSION}`;
    /// - deprecated event routes under `/{HTTP_API_VERSION}/events`;
    /// - Loki routes under `/{HTTP_API_VERSION}/loki`;
    /// - Elasticsearch routes under `/{HTTP_API_VERSION}/elasticsearch`, plus a `GET` route
    ///   served by `elasticsearch::handle_get_version` under
    ///   `/{HTTP_API_VERSION}/elasticsearch/`.
    pub fn with_log_ingest_handler(
        self,
        handler: PipelineHandlerRef,
        validator: Option<LogValidatorRef>,
        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
    ) -> Self {
        let log_state = LogState {
            log_handler: handler,
            log_validator: validator,
            ingest_interceptor,
        };

        let router = self.router.nest(
            &format!("/{HTTP_API_VERSION}"),
            HttpServer::route_pipelines(log_state.clone()),
        );
        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
        let router = router.nest(
            &format!("/{HTTP_API_VERSION}/events"),
            #[allow(deprecated)]
            HttpServer::route_log_deprecated(log_state.clone()),
        );

        let router = router.nest(
            &format!("/{HTTP_API_VERSION}/loki"),
            HttpServer::route_loki(log_state.clone()),
        );

        let router = router.nest(
            &format!("/{HTTP_API_VERSION}/elasticsearch"),
            HttpServer::route_elasticsearch(log_state.clone()),
        );

        // NOTE(review): presumably nested separately so that the trailing-slash path
        // (`/v1/elasticsearch/`) also answers with the version info — confirm.
        let router = router.nest(
            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
            Router::new()
                .route("/", routing::get(elasticsearch::handle_get_version))
                .with_state(log_state),
        );

        Self { router, ..self }
    }

    /// Replaces the plugins handed to the server.
    pub fn with_plugins(self, plugins: Plugins) -> Self {
        Self { plugins, ..self }
    }

    /// Merges the config route from `HttpServer::route_config`, which exposes `opts`.
    pub fn with_greptime_config_options(self, opts: String) -> Self {
        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
            greptime_config_options: opts,
        });

        Self {
            router: self.router.merge(config_router),
            ..self
        }
    }

    /// Mounts the Jaeger query routes under `/{HTTP_API_VERSION}/jaeger`.
    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/jaeger"),
                HttpServer::route_jaeger(handler),
            ),
            ..self
        }
    }

    /// Merges an arbitrary caller-provided `router` at the root.
    pub fn with_extra_router(self, router: Router) -> Self {
        Self {
            router: self.router.merge(router),
            ..self
        }
    }

    /// Consumes the builder and creates the [HttpServer].
    ///
    /// The request memory limiter is sized from [HttpOptions::max_total_body_memory]
    /// (0 disables it). The server starts with no shutdown sender and no bound address.
    pub fn build(self) -> HttpServer {
        let memory_limiter =
            RequestMemoryLimiter::new(self.options.max_total_body_memory.as_bytes() as usize);
        HttpServer {
            options: self.options,
            user_provider: self.user_provider,
            shutdown_tx: Mutex::new(None),
            plugins: self.plugins,
            router: StdMutex::new(self.router),
            bind_addr: None,
            memory_limiter,
        }
    }
}
769
770impl HttpServer {
771    /// Gets the router and adds necessary root routes (health, status, dashboard).
772    pub fn make_app(&self) -> Router {
773        let mut router = {
774            let router = self.router.lock().unwrap();
775            router.clone()
776        };
777
778        router = router
779            .route(
780                "/health",
781                routing::get(handler::health).post(handler::health),
782            )
783            .route(
784                &format!("/{HTTP_API_VERSION}/health"),
785                routing::get(handler::health).post(handler::health),
786            )
787            .route(
788                "/ready",
789                routing::get(handler::health).post(handler::health),
790            );
791
792        router = router.route("/status", routing::get(handler::status));
793
794        #[cfg(feature = "dashboard")]
795        {
796            if !self.options.disable_dashboard {
797                info!("Enable dashboard service at '/dashboard'");
798                // redirect /dashboard to /dashboard/
799                router = router.route(
800                    "/dashboard",
801                    routing::get(|uri: axum::http::uri::Uri| async move {
802                        let path = uri.path();
803                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
804
805                        let new_uri = format!("{}/{}", path, query);
806                        axum::response::Redirect::permanent(&new_uri)
807                    }),
808                );
809
810                // "/dashboard" and "/dashboard/" are two different paths in Axum.
811                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
812                // So we explicitly route "/dashboard/" here.
813                router = router
814                    .route(
815                        "/dashboard/",
816                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
817                    )
818                    .route(
819                        "/dashboard/{*x}",
820                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
821                    );
822            }
823        }
824
825        // Add a layer to collect HTTP metrics for axum.
826        router = router.route_layer(middleware::from_fn(http_metrics_layer));
827
828        router
829    }
830
    /// Attaches middlewares and debug routes to the router.
    /// Callers should call this method after [HttpServer::make_app()].
    ///
    /// Middlewares are installed through one `ServiceBuilder`, outermost first:
    /// - tracing;
    /// - CORS (if enabled);
    /// - request timeout (if non-zero);
    /// - default body limit (if non-zero);
    /// - total body memory limit;
    /// - authentication;
    /// - hints extraction;
    /// - read-preference extraction.
    ///
    /// # Errors
    /// Returns an `InvalidHeaderValue` error if an entry of
    /// [HttpOptions::cors_allowed_origins] is not a valid header value.
    pub fn build(&self, router: Router) -> Result<Router> {
        // A zero timeout (the default) means "no timeout".
        let timeout_layer = if self.options.timeout != Duration::default() {
            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
        } else {
            info!("HTTP server timeout is disabled");
            None
        };
        // A zero body limit disables the per-request body size limit.
        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
            Some(
                ServiceBuilder::new()
                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
            )
        } else {
            info!("HTTP server body limit is disabled");
            None
        };
        // An empty origin list allows any origin; otherwise every configured origin must
        // parse as a header value, or the whole build fails.
        let cors_layer = if self.options.enable_cors {
            Some(
                CorsLayer::new()
                    .allow_methods([
                        Method::GET,
                        Method::POST,
                        Method::PUT,
                        Method::DELETE,
                        Method::HEAD,
                    ])
                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
                        AllowOrigin::from(Any)
                    } else {
                        AllowOrigin::from(
                            self.options
                                .cors_allowed_origins
                                .iter()
                                .map(|s| {
                                    HeaderValue::from_str(s.as_str())
                                        .context(InvalidHeaderValueSnafu)
                                })
                                .collect::<Result<Vec<HeaderValue>>>()?,
                        )
                    })
                    .allow_headers(Any),
            )
        } else {
            info!("HTTP server cross-origin is disabled");
            None
        };

        Ok(router
            // middlewares
            .layer(
                ServiceBuilder::new()
                    // disable on failure tracing. because printing out isn't very helpful,
                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
                    .layer(TraceLayer::new_for_http().on_failure(()))
                    .option_layer(cors_layer)
                    .option_layer(timeout_layer)
                    .option_layer(body_limit_layer)
                    // memory limit layer - must be before body is consumed
                    .layer(middleware::from_fn_with_state(
                        self.memory_limiter.clone(),
                        memory_limit::memory_limit_middleware,
                    ))
                    // auth layer
                    .layer(middleware::from_fn_with_state(
                        AuthState::new(self.user_provider.clone()),
                        authorize::check_http_auth,
                    ))
                    .layer(middleware::from_fn(hints::extract_hints))
                    .layer(middleware::from_fn(
                        read_preference::extract_read_preference,
                    )),
            )
            // Handlers for debug, we don't expect a timeout.
            // NOTE(review): `Router::layer` only wraps routes that already exist, so the
            // `/debug` routes nested below bypass every middleware above — authentication,
            // CORS, body limits and tracing included, not just the timeout. Confirm this is
            // intended.
            .nest(
                "/debug",
                Router::new()
                    // handler for changing log level dynamically
                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
                    .nest(
                        "/prof",
                        Router::new()
                            .route("/cpu", routing::post(pprof::pprof_handler))
                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
                            .route(
                                "/mem/activate",
                                routing::post(mem_prof::activate_heap_prof_handler),
                            )
                            .route(
                                "/mem/deactivate",
                                routing::post(mem_prof::deactivate_heap_prof_handler),
                            )
                            .route(
                                "/mem/status",
                                routing::get(mem_prof::heap_prof_status_handler),
                            ) // jemalloc gdump flag status and toggle
                            .route(
                                "/mem/gdump",
                                routing::get(mem_prof::gdump_status_handler)
                                    .post(mem_prof::gdump_toggle_handler),
                            ),
                    ),
            ))
    }
936
937    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
938        Router::new()
939            .route("/metrics", routing::get(handler::metrics))
940            .with_state(metrics_handler)
941    }
942
943    fn route_loki<S>(log_state: LogState) -> Router<S> {
944        Router::new()
945            .route("/api/v1/push", routing::post(loki::loki_ingest))
946            .layer(
947                ServiceBuilder::new()
948                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
949            )
950            .with_state(log_state)
951    }
952
953    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
954        Router::new()
955            // Return fake responsefor HEAD '/' request.
956            .route(
957                "/",
958                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
959            )
960            // Return fake response for Elasticsearch version request.
961            .route("/", routing::get(elasticsearch::handle_get_version))
962            // Return fake response for Elasticsearch license request.
963            .route("/_license", routing::get(elasticsearch::handle_get_license))
964            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
965            .route(
966                "/{index}/_bulk",
967                routing::post(elasticsearch::handle_bulk_api_with_index),
968            )
969            // Return fake response for Elasticsearch ilm request.
970            .route(
971                "/_ilm/policy/{*path}",
972                routing::any((
973                    HttpStatusCode::OK,
974                    elasticsearch::elasticsearch_headers(),
975                    axum::Json(serde_json::json!({})),
976                )),
977            )
978            // Return fake response for Elasticsearch index template request.
979            .route(
980                "/_index_template/{*path}",
981                routing::any((
982                    HttpStatusCode::OK,
983                    elasticsearch::elasticsearch_headers(),
984                    axum::Json(serde_json::json!({})),
985                )),
986            )
987            // Return fake response for Elasticsearch ingest pipeline request.
988            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
989            .route(
990                "/_ingest/{*path}",
991                routing::any((
992                    HttpStatusCode::OK,
993                    elasticsearch::elasticsearch_headers(),
994                    axum::Json(serde_json::json!({})),
995                )),
996            )
997            // Return fake response for Elasticsearch nodes discovery request.
998            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
999            .route(
1000                "/_nodes/{*path}",
1001                routing::any((
1002                    HttpStatusCode::OK,
1003                    elasticsearch::elasticsearch_headers(),
1004                    axum::Json(serde_json::json!({})),
1005                )),
1006            )
1007            // Return fake response for Logstash APIs requests.
1008            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
1009            .route(
1010                "/logstash/{*path}",
1011                routing::any((
1012                    HttpStatusCode::OK,
1013                    elasticsearch::elasticsearch_headers(),
1014                    axum::Json(serde_json::json!({})),
1015                )),
1016            )
1017            .route(
1018                "/_logstash/{*path}",
1019                routing::any((
1020                    HttpStatusCode::OK,
1021                    elasticsearch::elasticsearch_headers(),
1022                    axum::Json(serde_json::json!({})),
1023                )),
1024            )
1025            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1026            .with_state(log_state)
1027    }
1028
1029    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1030    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1031        Router::new()
1032            .route("/logs", routing::post(event::log_ingester))
1033            .route(
1034                "/pipelines/{pipeline_name}",
1035                routing::get(event::query_pipeline),
1036            )
1037            .route(
1038                "/pipelines/{pipeline_name}",
1039                routing::post(event::add_pipeline),
1040            )
1041            .route(
1042                "/pipelines/{pipeline_name}",
1043                routing::delete(event::delete_pipeline),
1044            )
1045            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1046            .layer(
1047                ServiceBuilder::new()
1048                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1049            )
1050            .with_state(log_state)
1051    }
1052
1053    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1054        Router::new()
1055            .route("/ingest", routing::post(event::log_ingester))
1056            .route(
1057                "/pipelines/{pipeline_name}",
1058                routing::get(event::query_pipeline),
1059            )
1060            .route(
1061                "/pipelines/{pipeline_name}/ddl",
1062                routing::get(event::query_pipeline_ddl),
1063            )
1064            .route(
1065                "/pipelines/{pipeline_name}",
1066                routing::post(event::add_pipeline),
1067            )
1068            .route(
1069                "/pipelines/{pipeline_name}",
1070                routing::delete(event::delete_pipeline),
1071            )
1072            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1073            .layer(
1074                ServiceBuilder::new()
1075                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1076            )
1077            .with_state(log_state)
1078    }
1079
1080    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1081        Router::new()
1082            .route("/sql", routing::get(handler::sql).post(handler::sql))
1083            .route(
1084                "/sql/parse",
1085                routing::get(handler::sql_parse).post(handler::sql_parse),
1086            )
1087            .route(
1088                "/sql/format",
1089                routing::get(handler::sql_format).post(handler::sql_format),
1090            )
1091            .route(
1092                "/promql",
1093                routing::get(handler::promql).post(handler::promql),
1094            )
1095            .with_state(api_state)
1096    }
1097
1098    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1099        Router::new()
1100            .route("/logs", routing::get(logs::logs).post(logs::logs))
1101            .with_state(log_handler)
1102    }
1103
1104    /// Route Prometheus [HTTP API].
1105    ///
1106    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1107    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1108        Router::new()
1109            .route(
1110                "/format_query",
1111                routing::post(format_query).get(format_query),
1112            )
1113            .route("/status/buildinfo", routing::get(build_info_query))
1114            .route("/query", routing::post(instant_query).get(instant_query))
1115            .route("/query_range", routing::post(range_query).get(range_query))
1116            .route("/labels", routing::post(labels_query).get(labels_query))
1117            .route("/series", routing::post(series_query).get(series_query))
1118            .route("/parse_query", routing::post(parse_query).get(parse_query))
1119            .route(
1120                "/label/{label_name}/values",
1121                routing::get(label_values_query),
1122            )
1123            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1124            .with_state(prometheus_handler)
1125    }
1126
1127    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1128    /// called `prom_store`.
1129    ///
1130    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1131    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1132    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1133        Router::new()
1134            .route("/read", routing::post(prom_store::remote_read))
1135            .route("/write", routing::post(prom_store::remote_write))
1136            .with_state(state)
1137    }
1138
    /// Builds the InfluxDB-compatible router.
    ///
    /// Routes: `POST /write` (`influxdb_write_v1`), `POST /api/v2/write` (`influxdb_write_v2`),
    /// `GET /ping` and `GET /health`, all sharing `influxdb_handler` as state.
    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
        Router::new()
            .route("/write", routing::post(influxdb_write_v1))
            .route("/api/v2/write", routing::post(influxdb_write_v2))
            // `Router::layer` only wraps the routes registered before it, so request
            // decompression applies to the two write endpoints above and not to `/ping` and
            // `/health` below. Keep this call between the two groups of routes.
            .layer(
                ServiceBuilder::new()
                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
            )
            .route("/ping", routing::get(influxdb_ping))
            .route("/health", routing::get(influxdb_health))
            .with_state(influxdb_handler)
    }
1151
1152    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1153        Router::new()
1154            .route("/api/put", routing::post(opentsdb::put))
1155            .with_state(opentsdb_handler)
1156    }
1157
1158    fn route_otlp<S>(
1159        otlp_handler: OpenTelemetryProtocolHandlerRef,
1160        with_metric_engine: bool,
1161    ) -> Router<S> {
1162        Router::new()
1163            .route("/v1/metrics", routing::post(otlp::metrics))
1164            .route("/v1/traces", routing::post(otlp::traces))
1165            .route("/v1/logs", routing::post(otlp::logs))
1166            .layer(
1167                ServiceBuilder::new()
1168                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1169            )
1170            .with_state(OtlpState {
1171                with_metric_engine,
1172                handler: otlp_handler,
1173            })
1174    }
1175
1176    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1177        Router::new()
1178            .route("/config", routing::get(handler::config))
1179            .with_state(state)
1180    }
1181
1182    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1183        Router::new()
1184            .route("/api/services", routing::get(jaeger::handle_get_services))
1185            .route(
1186                "/api/services/{service_name}/operations",
1187                routing::get(jaeger::handle_get_operations_by_service),
1188            )
1189            .route(
1190                "/api/operations",
1191                routing::get(jaeger::handle_get_operations),
1192            )
1193            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1194            .route(
1195                "/api/traces/{trace_id}",
1196                routing::get(jaeger::handle_get_trace),
1197            )
1198            .with_state(handler)
1199    }
1200}
1201
/// Name reported by `HttpServer` through `Server::name`.
pub const HTTP_SERVER: &str = "HTTP_SERVER";
1203
1204#[async_trait]
1205impl Server for HttpServer {
1206    async fn shutdown(&self) -> Result<()> {
1207        let mut shutdown_tx = self.shutdown_tx.lock().await;
1208        if let Some(tx) = shutdown_tx.take()
1209            && tx.send(()).is_err()
1210        {
1211            info!("Receiver dropped, the HTTP server has already exited");
1212        }
1213        info!("Shutdown HTTP server");
1214
1215        Ok(())
1216    }
1217
1218    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1219        let (tx, rx) = oneshot::channel();
1220        let serve = {
1221            let mut shutdown_tx = self.shutdown_tx.lock().await;
1222            ensure!(
1223                shutdown_tx.is_none(),
1224                AlreadyStartedSnafu { server: "HTTP" }
1225            );
1226
1227            let mut app = self.make_app();
1228            if let Some(configurator) = self.plugins.get::<ConfiguratorRef>() {
1229                app = configurator.config_http(app);
1230            }
1231            let app = self.build(app)?;
1232            let listener = tokio::net::TcpListener::bind(listening)
1233                .await
1234                .context(AddressBindSnafu { addr: listening })?
1235                .tap_io(|tcp_stream| {
1236                    if let Err(e) = tcp_stream.set_nodelay(true) {
1237                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1238                    }
1239                });
1240            let serve = axum::serve(listener, app.into_make_service());
1241
1242            // FIXME(yingwen): Support keepalive.
1243            // See:
1244            // - https://github.com/tokio-rs/axum/discussions/2939
1245            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1246            // let server = axum::Server::try_bind(&listening)
1247            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1248            //     .tcp_nodelay(true)
1249            //     // Enable TCP keepalive to close the dangling established connections.
1250            //     // It's configured to let the keepalive probes first send after the connection sits
1251            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1252            //     // So the connection will be closed after roughly 1 hour.
1253            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1254            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1255            //     .tcp_keepalive_retries(Some(6))
1256            //     .serve(app.into_make_service());
1257
1258            *shutdown_tx = Some(tx);
1259
1260            serve
1261        };
1262        let listening = serve.local_addr().context(InternalIoSnafu)?;
1263        info!("HTTP server is bound to {}", listening);
1264
1265        common_runtime::spawn_global(async move {
1266            if let Err(e) = serve
1267                .with_graceful_shutdown(rx.map(drop))
1268                .await
1269                .context(InternalIoSnafu)
1270            {
1271                error!(e; "Failed to shutdown http server");
1272            }
1273        });
1274
1275        self.bind_addr = Some(listening);
1276        Ok(())
1277    }
1278
1279    fn name(&self) -> &str {
1280        HTTP_SERVER
1281    }
1282
1283    fn bind_addr(&self) -> Option<SocketAddr> {
1284        self.bind_addr
1285    }
1286}
1287
1288#[cfg(test)]
1289mod test {
1290    use std::future::pending;
1291    use std::io::Cursor;
1292    use std::sync::Arc;
1293
1294    use arrow_ipc::reader::FileReader;
1295    use arrow_schema::DataType;
1296    use axum::handler::Handler;
1297    use axum::http::StatusCode;
1298    use axum::routing::get;
1299    use common_query::Output;
1300    use common_recordbatch::RecordBatches;
1301    use datafusion_expr::LogicalPlan;
1302    use datatypes::prelude::*;
1303    use datatypes::schema::{ColumnSchema, Schema};
1304    use datatypes::vectors::{StringVector, UInt32Vector};
1305    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1306    use query::parser::PromQuery;
1307    use query::query_engine::DescribeResult;
1308    use session::context::QueryContextRef;
1309    use sql::statements::statement::Statement;
1310    use tokio::sync::mpsc;
1311    use tokio::time::Instant;
1312
1313    use super::*;
1314    use crate::error::Error;
1315    use crate::http::test_helpers::TestClient;
1316    use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
1317
    /// Stand-in SQL query handler used to build a test `HttpServer`.
    struct DummyInstance {
        // Stored but never read by the handler implementation.
        _tx: mpsc::Sender<(String, Vec<u8>)>,
    }
1321
1322    #[async_trait]
1323    impl SqlQueryHandler for DummyInstance {
1324        type Error = Error;
1325
1326        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1327            unimplemented!()
1328        }
1329
1330        async fn do_promql_query(
1331            &self,
1332            _: &PromQuery,
1333            _: QueryContextRef,
1334        ) -> Vec<std::result::Result<Output, Self::Error>> {
1335            unimplemented!()
1336        }
1337
1338        async fn do_exec_plan(
1339            &self,
1340            _stmt: Option<Statement>,
1341            _plan: LogicalPlan,
1342            _query_ctx: QueryContextRef,
1343        ) -> std::result::Result<Output, Self::Error> {
1344            unimplemented!()
1345        }
1346
1347        async fn do_describe(
1348            &self,
1349            _stmt: sql::statements::statement::Statement,
1350            _query_ctx: QueryContextRef,
1351        ) -> Result<Option<DescribeResult>> {
1352            unimplemented!()
1353        }
1354
1355        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1356            Ok(true)
1357        }
1358    }
1359
1360    fn timeout() -> DynamicTimeoutLayer {
1361        DynamicTimeoutLayer::new(Duration::from_millis(10))
1362    }
1363
1364    async fn forever() {
1365        pending().await
1366    }
1367
1368    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1369        make_test_app_custom(tx, HttpOptions::default())
1370    }
1371
1372    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1373        let instance = Arc::new(DummyInstance { _tx: tx });
1374        let sql_instance = ServerSqlQueryHandlerAdapter::arc(instance.clone());
1375        let server = HttpServerBuilder::new(options)
1376            .with_sql_handler(sql_instance)
1377            .build();
1378        server.build(server.make_app()).unwrap().route(
1379            "/test/timeout",
1380            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1381        )
1382    }
1383
1384    #[tokio::test]
1385    pub async fn test_cors() {
1386        // cors is on by default
1387        let (tx, _rx) = mpsc::channel(100);
1388        let app = make_test_app(tx);
1389        let client = TestClient::new(app).await;
1390
1391        let res = client.get("/health").send().await;
1392
1393        assert_eq!(res.status(), StatusCode::OK);
1394        assert_eq!(
1395            res.headers()
1396                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1397                .expect("expect cors header origin"),
1398            "*"
1399        );
1400
1401        let res = client.get("/v1/health").send().await;
1402
1403        assert_eq!(res.status(), StatusCode::OK);
1404        assert_eq!(
1405            res.headers()
1406                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1407                .expect("expect cors header origin"),
1408            "*"
1409        );
1410
1411        let res = client
1412            .options("/health")
1413            .header("Access-Control-Request-Headers", "x-greptime-auth")
1414            .header("Access-Control-Request-Method", "DELETE")
1415            .header("Origin", "https://example.com")
1416            .send()
1417            .await;
1418        assert_eq!(res.status(), StatusCode::OK);
1419        assert_eq!(
1420            res.headers()
1421                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1422                .expect("expect cors header origin"),
1423            "*"
1424        );
1425        assert_eq!(
1426            res.headers()
1427                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1428                .expect("expect cors header headers"),
1429            "*"
1430        );
1431        assert_eq!(
1432            res.headers()
1433                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1434                .expect("expect cors header methods"),
1435            "GET,POST,PUT,DELETE,HEAD"
1436        );
1437    }
1438
    /// With `cors_allowed_origins` set, only listed origins get an
    /// `Access-Control-Allow-Origin` header back; other origins get none.
    #[tokio::test]
    pub async fn test_cors_custom_origins() {
        // CORS stays enabled (the default); only the allowed origins are customized.
        let (tx, _rx) = mpsc::channel(100);
        let origin = "https://example.com";

        let options = HttpOptions {
            cors_allowed_origins: vec![origin.to_string()],
            ..Default::default()
        };

        let app = make_test_app_custom(tx, options);
        let client = TestClient::new(app).await;

        // An allowed origin is echoed back.
        let res = client.get("/health").header("Origin", origin).send().await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .expect("expect cors header origin"),
            origin
        );

        // A non-listed origin still gets a normal response, but without the CORS header.
        let res = client
            .get("/health")
            .header("Origin", "https://notallowed.com")
            .send()
            .await;

        assert_eq!(res.status(), StatusCode::OK);
        assert!(
            !res.headers()
                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
        );
    }
1475
    /// With `enable_cors: false`, responses carry no `Access-Control-Allow-Origin` header.
    #[tokio::test]
    pub async fn test_cors_disabled() {
        // CORS is on by default, so it has to be turned off explicitly here.
        let (tx, _rx) = mpsc::channel(100);

        let options = HttpOptions {
            enable_cors: false,
            ..Default::default()
        };

        let app = make_test_app_custom(tx, options);
        let client = TestClient::new(app).await;

        let res = client.get("/health").send().await;

        assert_eq!(res.status(), StatusCode::OK);
        assert!(
            !res.headers()
                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
        );
    }
1497
1498    #[test]
1499    fn test_http_options_default() {
1500        let default = HttpOptions::default();
1501        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1502        assert_eq!(Duration::from_secs(0), default.timeout)
1503    }
1504
    /// Exercises the `/test/timeout` route (a never-finishing handler behind `timeout()`):
    /// - the layer's 10ms default produces `408 Request Timeout`;
    /// - a per-request `GREPTIME_DB_HEADER_TIMEOUT` of 20ms also yields 408, but only after
    ///   more than 15ms, i.e. the header overrides the default;
    /// - a zero timeout ("0s" or the humantime rendering of `Duration::default()`) leaves the
    ///   request still pending after 15ms, so a zero value apparently disables the timeout.
    #[tokio::test]
    async fn test_http_server_request_timeout() {
        common_telemetry::init_default_ut_logging();

        let (tx, _rx) = mpsc::channel(100);
        let app = make_test_app(tx);
        let client = TestClient::new(app).await;
        // Default (10ms) timeout from the layer.
        let res = client.get("/test/timeout").send().await;
        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);

        // Header-provided 20ms timeout: must take noticeably longer than the 10ms default.
        let now = Instant::now();
        let res = client
            .get("/test/timeout")
            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
            .send()
            .await;
        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
        let elapsed = now.elapsed();
        assert!(elapsed > Duration::from_millis(15));

        // Zero timeout: the request must NOT complete within 15ms, so the outer
        // `tokio::time::timeout` is expected to fire (`unwrap_err`).
        tokio::time::timeout(
            Duration::from_millis(15),
            client
                .get("/test/timeout")
                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
                .send(),
        )
        .await
        .unwrap_err();

        // Same check with a zero duration rendered by `humantime`.
        tokio::time::timeout(
            Duration::from_millis(15),
            client
                .get("/test/timeout")
                .header(
                    GREPTIME_DB_HEADER_TIMEOUT,
                    humantime::format_duration(Duration::default()).to_string(),
                )
                .send(),
        )
        .await
        .unwrap_err();
    }
1548
1549    #[tokio::test]
1550    async fn test_schema_for_empty_response() {
1551        let column_schemas = vec![
1552            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1553            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1554        ];
1555        let schema = Arc::new(Schema::new(column_schemas));
1556
1557        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1558        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1559
1560        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1561        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1562            let json_output = &json_resp.output[0];
1563            if let GreptimeQueryOutput::Records(r) = json_output {
1564                assert_eq!(r.num_rows(), 0);
1565                assert_eq!(r.num_cols(), 2);
1566                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1567                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1568            } else {
1569                panic!("invalid output type");
1570            }
1571        } else {
1572            panic!("invalid format")
1573        }
1574    }
1575
    /// Converts the same 4-row, 2-column record batch (`numbers: UInt32`, nullable
    /// `strings: String`) through every `ResponseFormat` and checks each resulting
    /// `HttpResponse` variant: row/column counts, the first column's name/type, and that the
    /// first row is `(1, null)`.
    #[tokio::test]
    async fn test_recordbatches_conversion() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();

        // One iteration per supported output format.
        for format in [
            ResponseFormat::GreptimedbV1,
            ResponseFormat::InfluxdbV1,
            ResponseFormat::Csv(true, true),
            ResponseFormat::Table,
            ResponseFormat::Arrow,
            ResponseFormat::Json,
            ResponseFormat::Null,
        ] {
            // Fresh outputs each time: `from_output` consumes them.
            let recordbatches =
                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
            // Despite its name, `json_resp` holds whichever response the format produced.
            let json_resp = match format {
                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
                ResponseFormat::Csv(with_names, with_types) => {
                    CsvResponse::from_output(outputs, with_names, with_types).await
                }
                ResponseFormat::Table => TableResponse::from_output(outputs).await,
                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
                ResponseFormat::Null => NullResponse::from_output(outputs).await,
            };

            match json_resp {
                HttpResponse::GreptimedbV1(resp) => {
                    let json_output = &resp.output[0];
                    if let GreptimeQueryOutput::Records(r) = json_output {
                        assert_eq!(r.num_rows(), 4);
                        assert_eq!(r.num_cols(), 2);
                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
                    } else {
                        panic!("invalid output type");
                    }
                }
                // InfluxDB v1 responses expose data as series of columns/values.
                HttpResponse::InfluxdbV1(resp) => {
                    let json_output = &resp.results()[0];
                    assert_eq!(json_output.num_rows(), 4);
                    assert_eq!(json_output.num_cols(), 2);
                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
                    assert_eq!(
                        json_output.series[0].values[0][0],
                        serde_json::Value::from(1)
                    );
                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
                }
                HttpResponse::Csv(resp) => {
                    let output = &resp.output()[0];
                    if let GreptimeQueryOutput::Records(r) = output {
                        assert_eq!(r.num_rows(), 4);
                        assert_eq!(r.num_cols(), 2);
                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
                    } else {
                        panic!("invalid output type");
                    }
                }

                HttpResponse::Table(resp) => {
                    let output = &resp.output()[0];
                    if let GreptimeQueryOutput::Records(r) = output {
                        assert_eq!(r.num_rows(), 4);
                        assert_eq!(r.num_cols(), 2);
                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
                    } else {
                        panic!("invalid output type");
                    }
                }

                // Arrow responses carry raw Arrow IPC file bytes; decode them and check the
                // Arrow schema and batch dimensions.
                HttpResponse::Arrow(resp) => {
                    let output = resp.data;
                    let mut reader =
                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
                    let schema = reader.schema();
                    assert_eq!(schema.fields[0].name(), "numbers");
                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
                    assert_eq!(schema.fields[1].name(), "strings");
                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);

                    let rb = reader.next().unwrap().expect("read record batch failed");
                    assert_eq!(rb.num_columns(), 2);
                    assert_eq!(rb.num_rows(), 4);
                }

                HttpResponse::Json(resp) => {
                    let output = &resp.output()[0];
                    if let GreptimeQueryOutput::Records(r) = output {
                        assert_eq!(r.num_rows(), 4);
                        assert_eq!(r.num_cols(), 2);
                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
                    } else {
                        panic!("invalid output type");
                    }
                }

                // The null format only reports the row count.
                HttpResponse::Null(resp) => {
                    assert_eq!(resp.rows(), 4);
                }

                HttpResponse::Error(err) => unreachable!("{err:?}"),
            }
        }
    }
1708
1709    #[test]
1710    fn test_response_format_misc() {
1711        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
1712        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
1713        assert_eq!(
1714            ResponseFormat::parse("csv"),
1715            Some(ResponseFormat::Csv(false, false))
1716        );
1717        assert_eq!(
1718            ResponseFormat::parse("csvwithnames"),
1719            Some(ResponseFormat::Csv(true, false))
1720        );
1721        assert_eq!(
1722            ResponseFormat::parse("csvwithnamesandtypes"),
1723            Some(ResponseFormat::Csv(true, true))
1724        );
1725        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
1726        assert_eq!(
1727            ResponseFormat::parse("greptimedb_v1"),
1728            Some(ResponseFormat::GreptimedbV1)
1729        );
1730        assert_eq!(
1731            ResponseFormat::parse("influxdb_v1"),
1732            Some(ResponseFormat::InfluxdbV1)
1733        );
1734        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
1735        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));
1736
1737        // invalid formats
1738        assert_eq!(ResponseFormat::parse("invalid"), None);
1739        assert_eq!(ResponseFormat::parse(""), None);
1740        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive
1741
1742        // as str
1743        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
1744        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
1745        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
1746        assert_eq!(ResponseFormat::Table.as_str(), "table");
1747        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
1748        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
1749        assert_eq!(ResponseFormat::Json.as_str(), "json");
1750        assert_eq!(ResponseFormat::Null.as_str(), "null");
1751        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
1752    }
1753}