servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::Display;
17use std::net::SocketAddr;
18use std::sync::Mutex as StdMutex;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use auth::UserProviderRef;
23use axum::extract::DefaultBodyLimit;
24use axum::http::StatusCode as HttpStatusCode;
25use axum::response::{IntoResponse, Response};
26use axum::serve::ListenerExt;
27use axum::{middleware, routing, Router};
28use common_base::readable_size::ReadableSize;
29use common_base::Plugins;
30use common_recordbatch::RecordBatch;
31use common_telemetry::{debug, error, info};
32use common_time::timestamp::TimeUnit;
33use common_time::Timestamp;
34use datatypes::data_type::DataType;
35use datatypes::schema::SchemaRef;
36use datatypes::value::transform_value_ref_to_json_value;
37use event::{LogState, LogValidatorRef};
38use futures::FutureExt;
39use http::{HeaderValue, Method};
40use prost::DecodeError;
41use serde::{Deserialize, Serialize};
42use serde_json::Value;
43use snafu::{ensure, ResultExt};
44use tokio::sync::oneshot::{self, Sender};
45use tokio::sync::Mutex;
46use tower::ServiceBuilder;
47use tower_http::compression::CompressionLayer;
48use tower_http::cors::{AllowOrigin, Any, CorsLayer};
49use tower_http::decompression::RequestDecompressionLayer;
50use tower_http::trace::TraceLayer;
51
52use self::authorize::AuthState;
53use self::result::table_result::TableResponse;
54use crate::configurator::ConfiguratorRef;
55use crate::elasticsearch;
56use crate::error::{
57    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu, Result,
58    ToJsonSnafu,
59};
60use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
61use crate::http::prom_store::PromStoreState;
62use crate::http::prometheus::{
63    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
64    range_query, series_query,
65};
66use crate::http::result::arrow_result::ArrowResponse;
67use crate::http::result::csv_result::CsvResponse;
68use crate::http::result::error_result::ErrorResponse;
69use crate::http::result::greptime_result_v1::GreptimedbV1Response;
70use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
71use crate::http::result::json_result::JsonResponse;
72use crate::http::result::null_result::NullResponse;
73use crate::interceptor::LogIngestInterceptorRef;
74use crate::metrics::http_metrics_layer;
75use crate::metrics_handler::MetricsHandler;
76use crate::prometheus_handler::PrometheusHandlerRef;
77use crate::query_handler::sql::ServerSqlQueryHandlerRef;
78use crate::query_handler::{
79    InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
80    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
81    PromStoreProtocolHandlerRef,
82};
83use crate::server::Server;
84
85pub mod authorize;
86#[cfg(feature = "dashboard")]
87mod dashboard;
88pub mod dyn_log;
89pub mod event;
90pub mod extractor;
91pub mod handler;
92pub mod header;
93pub mod influxdb;
94pub mod jaeger;
95pub mod logs;
96pub mod loki;
97pub mod mem_prof;
98pub mod opentsdb;
99pub mod otlp;
100pub mod pprof;
101pub mod prom_store;
102pub mod prometheus;
103pub mod result;
104mod timeout;
105
106pub(crate) use timeout::DynamicTimeoutLayer;
107
108mod hints;
109mod read_preference;
110#[cfg(any(test, feature = "testing"))]
111pub mod test_helpers;
112
/// Version segment used by every versioned HTTP API path (e.g. `/v1/sql`).
pub const HTTP_API_VERSION: &str = "v1";
/// Path prefix of versioned HTTP APIs; must stay in sync with [`HTTP_API_VERSION`].
pub const HTTP_API_PREFIX: &str = "/v1/";
115/// Default http body limit (64M).
116const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);
117
/// Name of GreptimeDB's custom authorization header.
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";

/// Paths intended to be publicly reachable (presumably exempted from auth by the `authorize`
/// module — confirm there).
// TODO(fys): This is a temporary workaround, it will be improved later
pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
123
124#[derive(Default)]
125pub struct HttpServer {
126    router: StdMutex<Router>,
127    shutdown_tx: Mutex<Option<Sender<()>>>,
128    user_provider: Option<UserProviderRef>,
129
130    // plugins
131    plugins: Plugins,
132
133    // server configs
134    options: HttpOptions,
135    bind_addr: Option<SocketAddr>,
136}
137
138#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
139#[serde(default)]
140pub struct HttpOptions {
141    pub addr: String,
142
143    #[serde(with = "humantime_serde")]
144    pub timeout: Duration,
145
146    #[serde(skip)]
147    pub disable_dashboard: bool,
148
149    pub body_limit: ReadableSize,
150
151    /// Validation mode while decoding Prometheus remote write requests.
152    pub prom_validation_mode: PromValidationMode,
153
154    pub cors_allowed_origins: Vec<String>,
155
156    pub enable_cors: bool,
157}
158
159#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
160#[serde(rename_all = "snake_case")]
161pub enum PromValidationMode {
162    /// Force UTF8 validation
163    Strict,
164    /// Allow lossy UTF8 strings
165    Lossy,
166    /// Do not validate UTF8 strings.
167    Unchecked,
168}
169
170impl PromValidationMode {
171    /// Decodes provided bytes to [String] with optional UTF-8 validation.
172    pub fn decode_string(&self, bytes: &[u8]) -> std::result::Result<String, DecodeError> {
173        let result = match self {
174            PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
175                Ok(s) => s,
176                Err(e) => {
177                    debug!("Invalid UTF-8 string value: {:?}, error: {:?}", bytes, e);
178                    return Err(DecodeError::new("invalid utf-8"));
179                }
180            },
181            PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
182            PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
183        };
184        Ok(result)
185    }
186}
187
188impl Default for HttpOptions {
189    fn default() -> Self {
190        Self {
191            addr: "127.0.0.1:4000".to_string(),
192            timeout: Duration::from_secs(0),
193            disable_dashboard: false,
194            body_limit: DEFAULT_BODY_LIMIT,
195            cors_allowed_origins: Vec::new(),
196            enable_cors: true,
197            prom_validation_mode: PromValidationMode::Strict,
198        }
199    }
200}
201
202#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
203pub struct ColumnSchema {
204    name: String,
205    data_type: String,
206}
207
208impl ColumnSchema {
209    pub fn new(name: String, data_type: String) -> ColumnSchema {
210        ColumnSchema { name, data_type }
211    }
212}
213
214#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
215pub struct OutputSchema {
216    column_schemas: Vec<ColumnSchema>,
217}
218
219impl OutputSchema {
220    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
221        OutputSchema {
222            column_schemas: columns,
223        }
224    }
225}
226
227impl From<SchemaRef> for OutputSchema {
228    fn from(schema: SchemaRef) -> OutputSchema {
229        OutputSchema {
230            column_schemas: schema
231                .column_schemas()
232                .iter()
233                .map(|cs| ColumnSchema {
234                    name: cs.name.clone(),
235                    data_type: cs.data_type.name(),
236                })
237                .collect(),
238        }
239    }
240}
241
242#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
243pub struct HttpRecordsOutput {
244    schema: OutputSchema,
245    rows: Vec<Vec<Value>>,
246    // total_rows is equal to rows.len() in most cases,
247    // the Dashboard query result may be truncated, so we need to return the total_rows.
248    #[serde(default)]
249    total_rows: usize,
250
251    // plan level execution metrics
252    #[serde(skip_serializing_if = "HashMap::is_empty")]
253    #[serde(default)]
254    metrics: HashMap<String, Value>,
255}
256
257impl HttpRecordsOutput {
258    pub fn num_rows(&self) -> usize {
259        self.rows.len()
260    }
261
262    pub fn num_cols(&self) -> usize {
263        self.schema.column_schemas.len()
264    }
265
266    pub fn schema(&self) -> &OutputSchema {
267        &self.schema
268    }
269
270    pub fn rows(&self) -> &Vec<Vec<Value>> {
271        &self.rows
272    }
273}
274
275impl HttpRecordsOutput {
276    pub fn try_new(
277        schema: SchemaRef,
278        recordbatches: Vec<RecordBatch>,
279    ) -> std::result::Result<HttpRecordsOutput, Error> {
280        if recordbatches.is_empty() {
281            Ok(HttpRecordsOutput {
282                schema: OutputSchema::from(schema),
283                rows: vec![],
284                total_rows: 0,
285                metrics: Default::default(),
286            })
287        } else {
288            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
289            let mut rows = Vec::with_capacity(num_rows);
290            let schemas = schema.column_schemas();
291            let num_cols = schema.column_schemas().len();
292            rows.resize_with(num_rows, || Vec::with_capacity(num_cols));
293
294            let mut finished_row_cursor = 0;
295            for recordbatch in recordbatches {
296                for (col_idx, col) in recordbatch.columns().iter().enumerate() {
297                    // safety here: schemas length is equal to the number of columns in the recordbatch
298                    let schema = &schemas[col_idx];
299                    for row_idx in 0..recordbatch.num_rows() {
300                        let value = transform_value_ref_to_json_value(col.get_ref(row_idx), schema)
301                            .context(ToJsonSnafu)?;
302                        rows[row_idx + finished_row_cursor].push(value);
303                    }
304                }
305                finished_row_cursor += recordbatch.num_rows();
306            }
307
308            Ok(HttpRecordsOutput {
309                schema: OutputSchema::from(schema),
310                total_rows: rows.len(),
311                rows,
312                metrics: Default::default(),
313            })
314        }
315    }
316}
317
318#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
319#[serde(rename_all = "lowercase")]
320pub enum GreptimeQueryOutput {
321    AffectedRows(usize),
322    Records(HttpRecordsOutput),
323}
324
/// Output format for SQL query results over HTTP.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    /// CSV output; the flags are `(with_names, with_types)`.
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Parses a format name (exact, case-sensitive match); unknown names yield `None`.
    ///
    /// `csv`, `csvwithnames` and `csvwithnamesandtypes` map to the three CSV flag combinations.
    pub fn parse(s: &str) -> Option<Self> {
        let format = match s {
            "arrow" => Self::Arrow,
            "csv" => Self::Csv(false, false),
            "csvwithnames" => Self::Csv(true, false),
            "csvwithnamesandtypes" => Self::Csv(true, true),
            "table" => Self::Table,
            "greptimedb_v1" => Self::GreptimedbV1,
            "influxdb_v1" => Self::InfluxdbV1,
            "json" => Self::Json,
            "null" => Self::Null,
            _ => return None,
        };
        Some(format)
    }

    /// Short name of the format; every CSV variant reports plain `"csv"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
367
368#[derive(Debug, Clone, Copy, PartialEq, Eq)]
369pub enum Epoch {
370    Nanosecond,
371    Microsecond,
372    Millisecond,
373    Second,
374}
375
376impl Epoch {
377    pub fn parse(s: &str) -> Option<Epoch> {
378        // Both u and µ indicate microseconds.
379        // epoch = [ns,u,µ,ms,s],
380        // For details, see the Influxdb documents.
381        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
382        match s {
383            "ns" => Some(Epoch::Nanosecond),
384            "u" | "µ" => Some(Epoch::Microsecond),
385            "ms" => Some(Epoch::Millisecond),
386            "s" => Some(Epoch::Second),
387            _ => None, // just returns None for other cases
388        }
389    }
390
391    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
392        match self {
393            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
394            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
395            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
396            Epoch::Second => ts.convert_to(TimeUnit::Second),
397        }
398    }
399}
400
401impl Display for Epoch {
402    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403        match self {
404            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
405            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
406            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
407            Epoch::Second => write!(f, "Epoch::Second"),
408        }
409    }
410}
411
412#[derive(Serialize, Deserialize, Debug)]
413pub enum HttpResponse {
414    Arrow(ArrowResponse),
415    Csv(CsvResponse),
416    Table(TableResponse),
417    Error(ErrorResponse),
418    GreptimedbV1(GreptimedbV1Response),
419    InfluxdbV1(InfluxdbV1Response),
420    Json(JsonResponse),
421    Null(NullResponse),
422}
423
424impl HttpResponse {
425    pub fn with_execution_time(self, execution_time: u64) -> Self {
426        match self {
427            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
428            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
429            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
430            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
431            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
432            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
433            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
434            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
435        }
436    }
437
438    pub fn with_limit(self, limit: usize) -> Self {
439        match self {
440            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
441            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
442            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
443            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
444            _ => self,
445        }
446    }
447}
448
449pub fn process_with_limit(
450    mut outputs: Vec<GreptimeQueryOutput>,
451    limit: usize,
452) -> Vec<GreptimeQueryOutput> {
453    outputs
454        .drain(..)
455        .map(|data| match data {
456            GreptimeQueryOutput::Records(mut records) => {
457                if records.rows.len() > limit {
458                    records.rows.truncate(limit);
459                    records.total_rows = limit;
460                }
461                GreptimeQueryOutput::Records(records)
462            }
463            _ => data,
464        })
465        .collect()
466}
467
468impl IntoResponse for HttpResponse {
469    fn into_response(self) -> Response {
470        match self {
471            HttpResponse::Arrow(resp) => resp.into_response(),
472            HttpResponse::Csv(resp) => resp.into_response(),
473            HttpResponse::Table(resp) => resp.into_response(),
474            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
475            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
476            HttpResponse::Json(resp) => resp.into_response(),
477            HttpResponse::Null(resp) => resp.into_response(),
478            HttpResponse::Error(resp) => resp.into_response(),
479        }
480    }
481}
482
483impl From<ArrowResponse> for HttpResponse {
484    fn from(value: ArrowResponse) -> Self {
485        HttpResponse::Arrow(value)
486    }
487}
488
489impl From<CsvResponse> for HttpResponse {
490    fn from(value: CsvResponse) -> Self {
491        HttpResponse::Csv(value)
492    }
493}
494
495impl From<TableResponse> for HttpResponse {
496    fn from(value: TableResponse) -> Self {
497        HttpResponse::Table(value)
498    }
499}
500
501impl From<ErrorResponse> for HttpResponse {
502    fn from(value: ErrorResponse) -> Self {
503        HttpResponse::Error(value)
504    }
505}
506
507impl From<GreptimedbV1Response> for HttpResponse {
508    fn from(value: GreptimedbV1Response) -> Self {
509        HttpResponse::GreptimedbV1(value)
510    }
511}
512
513impl From<InfluxdbV1Response> for HttpResponse {
514    fn from(value: InfluxdbV1Response) -> Self {
515        HttpResponse::InfluxdbV1(value)
516    }
517}
518
519impl From<JsonResponse> for HttpResponse {
520    fn from(value: JsonResponse) -> Self {
521        HttpResponse::Json(value)
522    }
523}
524
525impl From<NullResponse> for HttpResponse {
526    fn from(value: NullResponse) -> Self {
527        HttpResponse::Null(value)
528    }
529}
530
531#[derive(Clone)]
532pub struct ApiState {
533    pub sql_handler: ServerSqlQueryHandlerRef,
534}
535
/// Router state for the configuration route built by `HttpServer::route_config`.
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    /// Pre-rendered configuration text handed over by the caller.
    pub greptime_config_options: String,
}
540
541#[derive(Default)]
542pub struct HttpServerBuilder {
543    options: HttpOptions,
544    plugins: Plugins,
545    user_provider: Option<UserProviderRef>,
546    router: Router,
547}
548
549impl HttpServerBuilder {
550    pub fn new(options: HttpOptions) -> Self {
551        Self {
552            options,
553            plugins: Plugins::default(),
554            user_provider: None,
555            router: Router::new(),
556        }
557    }
558
559    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
560        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
561
562        Self {
563            router: self
564                .router
565                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
566            ..self
567        }
568    }
569
570    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
571        let logs_router = HttpServer::route_logs(logs_handler);
572
573        Self {
574            router: self
575                .router
576                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
577            ..self
578        }
579    }
580
581    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
582        Self {
583            router: self.router.nest(
584                &format!("/{HTTP_API_VERSION}/opentsdb"),
585                HttpServer::route_opentsdb(handler),
586            ),
587            ..self
588        }
589    }
590
591    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
592        Self {
593            router: self.router.nest(
594                &format!("/{HTTP_API_VERSION}/influxdb"),
595                HttpServer::route_influxdb(handler),
596            ),
597            ..self
598        }
599    }
600
601    pub fn with_prom_handler(
602        self,
603        handler: PromStoreProtocolHandlerRef,
604        pipeline_handler: Option<PipelineHandlerRef>,
605        prom_store_with_metric_engine: bool,
606        prom_validation_mode: PromValidationMode,
607    ) -> Self {
608        let state = PromStoreState {
609            prom_store_handler: handler,
610            pipeline_handler,
611            prom_store_with_metric_engine,
612            prom_validation_mode,
613        };
614
615        Self {
616            router: self.router.nest(
617                &format!("/{HTTP_API_VERSION}/prometheus"),
618                HttpServer::route_prom(state),
619            ),
620            ..self
621        }
622    }
623
624    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
625        Self {
626            router: self.router.nest(
627                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
628                HttpServer::route_prometheus(handler),
629            ),
630            ..self
631        }
632    }
633
634    pub fn with_otlp_handler(self, handler: OpenTelemetryProtocolHandlerRef) -> Self {
635        Self {
636            router: self.router.nest(
637                &format!("/{HTTP_API_VERSION}/otlp"),
638                HttpServer::route_otlp(handler),
639            ),
640            ..self
641        }
642    }
643
644    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
645        Self {
646            user_provider: Some(user_provider),
647            ..self
648        }
649    }
650
651    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
652        Self {
653            router: self.router.merge(HttpServer::route_metrics(handler)),
654            ..self
655        }
656    }
657
658    pub fn with_log_ingest_handler(
659        self,
660        handler: PipelineHandlerRef,
661        validator: Option<LogValidatorRef>,
662        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
663    ) -> Self {
664        let log_state = LogState {
665            log_handler: handler,
666            log_validator: validator,
667            ingest_interceptor,
668        };
669
670        let router = self.router.nest(
671            &format!("/{HTTP_API_VERSION}"),
672            HttpServer::route_pipelines(log_state.clone()),
673        );
674        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
675        let router = router.nest(
676            &format!("/{HTTP_API_VERSION}/events"),
677            #[allow(deprecated)]
678            HttpServer::route_log_deprecated(log_state.clone()),
679        );
680
681        let router = router.nest(
682            &format!("/{HTTP_API_VERSION}/loki"),
683            HttpServer::route_loki(log_state.clone()),
684        );
685
686        let router = router.nest(
687            &format!("/{HTTP_API_VERSION}/elasticsearch"),
688            HttpServer::route_elasticsearch(log_state.clone()),
689        );
690
691        let router = router.nest(
692            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
693            Router::new()
694                .route("/", routing::get(elasticsearch::handle_get_version))
695                .with_state(log_state),
696        );
697
698        Self { router, ..self }
699    }
700
701    pub fn with_plugins(self, plugins: Plugins) -> Self {
702        Self { plugins, ..self }
703    }
704
705    pub fn with_greptime_config_options(self, opts: String) -> Self {
706        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
707            greptime_config_options: opts,
708        });
709
710        Self {
711            router: self.router.merge(config_router),
712            ..self
713        }
714    }
715
716    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
717        Self {
718            router: self.router.nest(
719                &format!("/{HTTP_API_VERSION}/jaeger"),
720                HttpServer::route_jaeger(handler),
721            ),
722            ..self
723        }
724    }
725
726    pub fn with_extra_router(self, router: Router) -> Self {
727        Self {
728            router: self.router.merge(router),
729            ..self
730        }
731    }
732
733    pub fn build(self) -> HttpServer {
734        HttpServer {
735            options: self.options,
736            user_provider: self.user_provider,
737            shutdown_tx: Mutex::new(None),
738            plugins: self.plugins,
739            router: StdMutex::new(self.router),
740            bind_addr: None,
741        }
742    }
743}
744
745impl HttpServer {
746    /// Gets the router and adds necessary root routes (health, status, dashboard).
747    pub fn make_app(&self) -> Router {
748        let mut router = {
749            let router = self.router.lock().unwrap();
750            router.clone()
751        };
752
753        router = router
754            .route(
755                "/health",
756                routing::get(handler::health).post(handler::health),
757            )
758            .route(
759                &format!("/{HTTP_API_VERSION}/health"),
760                routing::get(handler::health).post(handler::health),
761            )
762            .route(
763                "/ready",
764                routing::get(handler::health).post(handler::health),
765            );
766
767        router = router.route("/status", routing::get(handler::status));
768
769        #[cfg(feature = "dashboard")]
770        {
771            if !self.options.disable_dashboard {
772                info!("Enable dashboard service at '/dashboard'");
773                // redirect /dashboard to /dashboard/
774                router = router.route(
775                    "/dashboard",
776                    routing::get(|uri: axum::http::uri::Uri| async move {
777                        let path = uri.path();
778                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
779
780                        let new_uri = format!("{}/{}", path, query);
781                        axum::response::Redirect::permanent(&new_uri)
782                    }),
783                );
784
785                // "/dashboard" and "/dashboard/" are two different paths in Axum.
786                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
787                // So we explicitly route "/dashboard/" here.
788                router = router
789                    .route(
790                        "/dashboard/",
791                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
792                    )
793                    .route(
794                        "/dashboard/{*x}",
795                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
796                    );
797            }
798        }
799
800        // Add a layer to collect HTTP metrics for axum.
801        router = router.route_layer(middleware::from_fn(http_metrics_layer));
802
803        router
804    }
805
806    /// Attaches middlewares and debug routes to the router.
807    /// Callers should call this method after [HttpServer::make_app()].
808    pub fn build(&self, router: Router) -> Result<Router> {
809        let timeout_layer = if self.options.timeout != Duration::default() {
810            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
811        } else {
812            info!("HTTP server timeout is disabled");
813            None
814        };
815        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
816            Some(
817                ServiceBuilder::new()
818                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
819            )
820        } else {
821            info!("HTTP server body limit is disabled");
822            None
823        };
824        let cors_layer = if self.options.enable_cors {
825            Some(
826                CorsLayer::new()
827                    .allow_methods([
828                        Method::GET,
829                        Method::POST,
830                        Method::PUT,
831                        Method::DELETE,
832                        Method::HEAD,
833                    ])
834                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
835                        AllowOrigin::from(Any)
836                    } else {
837                        AllowOrigin::from(
838                            self.options
839                                .cors_allowed_origins
840                                .iter()
841                                .map(|s| {
842                                    HeaderValue::from_str(s.as_str())
843                                        .context(InvalidHeaderValueSnafu)
844                                })
845                                .collect::<Result<Vec<HeaderValue>>>()?,
846                        )
847                    })
848                    .allow_headers(Any),
849            )
850        } else {
851            info!("HTTP server cross-origin is disabled");
852            None
853        };
854
855        Ok(router
856            // middlewares
857            .layer(
858                ServiceBuilder::new()
859                    // disable on failure tracing. because printing out isn't very helpful,
860                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
861                    .layer(TraceLayer::new_for_http().on_failure(()))
862                    .option_layer(cors_layer)
863                    .option_layer(timeout_layer)
864                    .option_layer(body_limit_layer)
865                    // auth layer
866                    .layer(middleware::from_fn_with_state(
867                        AuthState::new(self.user_provider.clone()),
868                        authorize::check_http_auth,
869                    ))
870                    .layer(middleware::from_fn(hints::extract_hints))
871                    .layer(middleware::from_fn(
872                        read_preference::extract_read_preference,
873                    )),
874            )
875            // Handlers for debug, we don't expect a timeout.
876            .nest(
877                "/debug",
878                Router::new()
879                    // handler for changing log level dynamically
880                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
881                    .nest(
882                        "/prof",
883                        Router::new()
884                            .route("/cpu", routing::post(pprof::pprof_handler))
885                            .route("/mem", routing::post(mem_prof::mem_prof_handler)),
886                    ),
887            ))
888    }
889
890    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
891        Router::new()
892            .route("/metrics", routing::get(handler::metrics))
893            .with_state(metrics_handler)
894    }
895
896    fn route_loki<S>(log_state: LogState) -> Router<S> {
897        Router::new()
898            .route("/api/v1/push", routing::post(loki::loki_ingest))
899            .layer(
900                ServiceBuilder::new()
901                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
902            )
903            .with_state(log_state)
904    }
905
906    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
907        Router::new()
908            // Return fake responsefor HEAD '/' request.
909            .route(
910                "/",
911                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
912            )
913            // Return fake response for Elasticsearch version request.
914            .route("/", routing::get(elasticsearch::handle_get_version))
915            // Return fake response for Elasticsearch license request.
916            .route("/_license", routing::get(elasticsearch::handle_get_license))
917            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
918            .route(
919                "/{index}/_bulk",
920                routing::post(elasticsearch::handle_bulk_api_with_index),
921            )
922            // Return fake response for Elasticsearch ilm request.
923            .route(
924                "/_ilm/policy/{*path}",
925                routing::any((
926                    HttpStatusCode::OK,
927                    elasticsearch::elasticsearch_headers(),
928                    axum::Json(serde_json::json!({})),
929                )),
930            )
931            // Return fake response for Elasticsearch index template request.
932            .route(
933                "/_index_template/{*path}",
934                routing::any((
935                    HttpStatusCode::OK,
936                    elasticsearch::elasticsearch_headers(),
937                    axum::Json(serde_json::json!({})),
938                )),
939            )
940            // Return fake response for Elasticsearch ingest pipeline request.
941            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
942            .route(
943                "/_ingest/{*path}",
944                routing::any((
945                    HttpStatusCode::OK,
946                    elasticsearch::elasticsearch_headers(),
947                    axum::Json(serde_json::json!({})),
948                )),
949            )
950            // Return fake response for Elasticsearch nodes discovery request.
951            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
952            .route(
953                "/_nodes/{*path}",
954                routing::any((
955                    HttpStatusCode::OK,
956                    elasticsearch::elasticsearch_headers(),
957                    axum::Json(serde_json::json!({})),
958                )),
959            )
960            // Return fake response for Logstash APIs requests.
961            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
962            .route(
963                "/logstash/{*path}",
964                routing::any((
965                    HttpStatusCode::OK,
966                    elasticsearch::elasticsearch_headers(),
967                    axum::Json(serde_json::json!({})),
968                )),
969            )
970            .route(
971                "/_logstash/{*path}",
972                routing::any((
973                    HttpStatusCode::OK,
974                    elasticsearch::elasticsearch_headers(),
975                    axum::Json(serde_json::json!({})),
976                )),
977            )
978            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
979            .with_state(log_state)
980    }
981
982    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
983    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
984        Router::new()
985            .route("/logs", routing::post(event::log_ingester))
986            .route(
987                "/pipelines/{pipeline_name}",
988                routing::get(event::query_pipeline),
989            )
990            .route(
991                "/pipelines/{pipeline_name}",
992                routing::post(event::add_pipeline),
993            )
994            .route(
995                "/pipelines/{pipeline_name}",
996                routing::delete(event::delete_pipeline),
997            )
998            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
999            .layer(
1000                ServiceBuilder::new()
1001                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1002            )
1003            .with_state(log_state)
1004    }
1005
1006    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1007        Router::new()
1008            .route("/ingest", routing::post(event::log_ingester))
1009            .route(
1010                "/pipelines/{pipeline_name}",
1011                routing::get(event::query_pipeline),
1012            )
1013            .route(
1014                "/pipelines/{pipeline_name}",
1015                routing::post(event::add_pipeline),
1016            )
1017            .route(
1018                "/pipelines/{pipeline_name}",
1019                routing::delete(event::delete_pipeline),
1020            )
1021            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1022            .layer(
1023                ServiceBuilder::new()
1024                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1025            )
1026            .with_state(log_state)
1027    }
1028
1029    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1030        Router::new()
1031            .route("/sql", routing::get(handler::sql).post(handler::sql))
1032            .route(
1033                "/sql/parse",
1034                routing::get(handler::sql_parse).post(handler::sql_parse),
1035            )
1036            .route(
1037                "/promql",
1038                routing::get(handler::promql).post(handler::promql),
1039            )
1040            .with_state(api_state)
1041    }
1042
1043    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1044        Router::new()
1045            .route("/logs", routing::get(logs::logs).post(logs::logs))
1046            .with_state(log_handler)
1047    }
1048
1049    /// Route Prometheus [HTTP API].
1050    ///
1051    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1052    fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1053        Router::new()
1054            .route(
1055                "/format_query",
1056                routing::post(format_query).get(format_query),
1057            )
1058            .route("/status/buildinfo", routing::get(build_info_query))
1059            .route("/query", routing::post(instant_query).get(instant_query))
1060            .route("/query_range", routing::post(range_query).get(range_query))
1061            .route("/labels", routing::post(labels_query).get(labels_query))
1062            .route("/series", routing::post(series_query).get(series_query))
1063            .route("/parse_query", routing::post(parse_query).get(parse_query))
1064            .route(
1065                "/label/{label_name}/values",
1066                routing::get(label_values_query),
1067            )
1068            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1069            .with_state(prometheus_handler)
1070    }
1071
1072    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1073    /// called `prom_store`.
1074    ///
1075    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1076    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1077    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1078        Router::new()
1079            .route("/read", routing::post(prom_store::remote_read))
1080            .route("/write", routing::post(prom_store::remote_write))
1081            .with_state(state)
1082    }
1083
1084    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1085        Router::new()
1086            .route("/write", routing::post(influxdb_write_v1))
1087            .route("/api/v2/write", routing::post(influxdb_write_v2))
1088            .layer(
1089                ServiceBuilder::new()
1090                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1091            )
1092            .route("/ping", routing::get(influxdb_ping))
1093            .route("/health", routing::get(influxdb_health))
1094            .with_state(influxdb_handler)
1095    }
1096
1097    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1098        Router::new()
1099            .route("/api/put", routing::post(opentsdb::put))
1100            .with_state(opentsdb_handler)
1101    }
1102
1103    fn route_otlp<S>(otlp_handler: OpenTelemetryProtocolHandlerRef) -> Router<S> {
1104        Router::new()
1105            .route("/v1/metrics", routing::post(otlp::metrics))
1106            .route("/v1/traces", routing::post(otlp::traces))
1107            .route("/v1/logs", routing::post(otlp::logs))
1108            .layer(
1109                ServiceBuilder::new()
1110                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1111            )
1112            .with_state(otlp_handler)
1113    }
1114
1115    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1116        Router::new()
1117            .route("/config", routing::get(handler::config))
1118            .with_state(state)
1119    }
1120
1121    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1122        Router::new()
1123            .route("/api/services", routing::get(jaeger::handle_get_services))
1124            .route(
1125                "/api/services/{service_name}/operations",
1126                routing::get(jaeger::handle_get_operations_by_service),
1127            )
1128            .route(
1129                "/api/operations",
1130                routing::get(jaeger::handle_get_operations),
1131            )
1132            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1133            .route(
1134                "/api/traces/{trace_id}",
1135                routing::get(jaeger::handle_get_trace),
1136            )
1137            .with_state(handler)
1138    }
1139}
1140
/// Name returned by `Server::name` for [`HttpServer`].
pub const HTTP_SERVER: &str = "HTTP_SERVER";
1142
1143#[async_trait]
1144impl Server for HttpServer {
1145    async fn shutdown(&self) -> Result<()> {
1146        let mut shutdown_tx = self.shutdown_tx.lock().await;
1147        if let Some(tx) = shutdown_tx.take() {
1148            if tx.send(()).is_err() {
1149                info!("Receiver dropped, the HTTP server has already exited");
1150            }
1151        }
1152        info!("Shutdown HTTP server");
1153
1154        Ok(())
1155    }
1156
1157    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1158        let (tx, rx) = oneshot::channel();
1159        let serve = {
1160            let mut shutdown_tx = self.shutdown_tx.lock().await;
1161            ensure!(
1162                shutdown_tx.is_none(),
1163                AlreadyStartedSnafu { server: "HTTP" }
1164            );
1165
1166            let mut app = self.make_app();
1167            if let Some(configurator) = self.plugins.get::<ConfiguratorRef>() {
1168                app = configurator.config_http(app);
1169            }
1170            let app = self.build(app)?;
1171            let listener = tokio::net::TcpListener::bind(listening)
1172                .await
1173                .context(AddressBindSnafu { addr: listening })?
1174                .tap_io(|tcp_stream| {
1175                    if let Err(e) = tcp_stream.set_nodelay(true) {
1176                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1177                    }
1178                });
1179            let serve = axum::serve(listener, app.into_make_service());
1180
1181            // FIXME(yingwen): Support keepalive.
1182            // See:
1183            // - https://github.com/tokio-rs/axum/discussions/2939
1184            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1185            // let server = axum::Server::try_bind(&listening)
1186            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1187            //     .tcp_nodelay(true)
1188            //     // Enable TCP keepalive to close the dangling established connections.
1189            //     // It's configured to let the keepalive probes first send after the connection sits
1190            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1191            //     // So the connection will be closed after roughly 1 hour.
1192            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1193            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1194            //     .tcp_keepalive_retries(Some(6))
1195            //     .serve(app.into_make_service());
1196
1197            *shutdown_tx = Some(tx);
1198
1199            serve
1200        };
1201        let listening = serve.local_addr().context(InternalIoSnafu)?;
1202        info!("HTTP server is bound to {}", listening);
1203
1204        common_runtime::spawn_global(async move {
1205            if let Err(e) = serve
1206                .with_graceful_shutdown(rx.map(drop))
1207                .await
1208                .context(InternalIoSnafu)
1209            {
1210                error!(e; "Failed to shutdown http server");
1211            }
1212        });
1213
1214        self.bind_addr = Some(listening);
1215        Ok(())
1216    }
1217
1218    fn name(&self) -> &str {
1219        HTTP_SERVER
1220    }
1221
1222    fn bind_addr(&self) -> Option<SocketAddr> {
1223        self.bind_addr
1224    }
1225}
1226
// Unit tests for the HTTP server: CORS headers, per-request timeouts, and conversion of
// query outputs into each supported `ResponseFormat`.
#[cfg(test)]
mod test {
    use std::future::pending;
    use std::io::Cursor;
    use std::sync::Arc;

    use arrow_ipc::reader::FileReader;
    use arrow_schema::DataType;
    use axum::handler::Handler;
    use axum::http::StatusCode;
    use axum::routing::get;
    use common_query::Output;
    use common_recordbatch::RecordBatches;
    use datafusion_expr::LogicalPlan;
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{StringVector, UInt32Vector};
    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
    use query::parser::PromQuery;
    use query::query_engine::DescribeResult;
    use session::context::QueryContextRef;
    use tokio::sync::mpsc;
    use tokio::time::Instant;

    use super::*;
    use crate::error::Error;
    use crate::http::test_helpers::TestClient;
    use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};

    /// Stub SQL handler used to build a test app.
    ///
    /// Only `is_valid_schema` is implemented (it always returns `Ok(true)`). Every query
    /// method panics via `unimplemented!()`, so tests must not route queries through it.
    struct DummyInstance {
        // Never used by the stub; it only owns the sender passed in by the test.
        _tx: mpsc::Sender<(String, Vec<u8>)>,
    }

    #[async_trait]
    impl SqlQueryHandler for DummyInstance {
        type Error = Error;

        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_promql_query(
            &self,
            _: &PromQuery,
            _: QueryContextRef,
        ) -> Vec<std::result::Result<Output, Self::Error>> {
            unimplemented!()
        }

        async fn do_exec_plan(
            &self,
            _plan: LogicalPlan,
            _query_ctx: QueryContextRef,
        ) -> std::result::Result<Output, Self::Error> {
            unimplemented!()
        }

        async fn do_describe(
            &self,
            _stmt: sql::statements::statement::Statement,
            _query_ctx: QueryContextRef,
        ) -> Result<Option<DescribeResult>> {
            unimplemented!()
        }

        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
            Ok(true)
        }
    }

    /// Timeout layer with a 10 ms default, applied to the `/test/timeout` route.
    fn timeout() -> DynamicTimeoutLayer {
        DynamicTimeoutLayer::new(Duration::from_millis(10))
    }

    /// Handler that never completes; a request to it can only end by timing out.
    async fn forever() {
        pending().await
    }

    /// Builds the test app with default `HttpOptions`.
    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
        make_test_app_custom(tx, HttpOptions::default())
    }

    /// Builds the full HTTP app from `options`, backed by `DummyInstance`.
    ///
    /// It also adds a `/test/timeout` route whose handler never finishes and is wrapped in
    /// `timeout()`.
    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
        let instance = Arc::new(DummyInstance { _tx: tx });
        let sql_instance = ServerSqlQueryHandlerAdapter::arc(instance.clone());
        let server = HttpServerBuilder::new(options)
            .with_sql_handler(sql_instance)
            .build();
        server.build(server.make_app()).unwrap().route(
            "/test/timeout",
            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
        )
    }

    /// With default options, both simple and preflight requests get wildcard CORS headers.
    #[tokio::test]
    pub async fn test_cors() {
        // cors is on by default
        let (tx, _rx) = mpsc::channel(100);
        let app = make_test_app(tx);
        let client = TestClient::new(app).await;

        // Simple GET on the unversioned health endpoint.
        let res = client.get("/health").send().await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .expect("expect cors header origin"),
            "*"
        );

        // Same check on the versioned health endpoint.
        let res = client.get("/v1/health").send().await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .expect("expect cors header origin"),
            "*"
        );

        // Preflight request: expect wildcard origin/headers and the allowed method list.
        let res = client
            .options("/health")
            .header("Access-Control-Request-Headers", "x-greptime-auth")
            .header("Access-Control-Request-Method", "DELETE")
            .header("Origin", "https://example.com")
            .send()
            .await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .expect("expect cors header origin"),
            "*"
        );
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
                .expect("expect cors header headers"),
            "*"
        );
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
                .expect("expect cors header methods"),
            "GET,POST,PUT,DELETE,HEAD"
        );
    }

    /// With `cors_allowed_origins` set, only listed origins are echoed back.
    #[tokio::test]
    pub async fn test_cors_custom_origins() {
        // CORS is on by default; here it is restricted to a single allowed origin.
        let (tx, _rx) = mpsc::channel(100);
        let origin = "https://example.com";

        let options = HttpOptions {
            cors_allowed_origins: vec![origin.to_string()],
            ..Default::default()
        };

        let app = make_test_app_custom(tx, options);
        let client = TestClient::new(app).await;

        // An allowed origin is echoed back.
        let res = client.get("/health").header("Origin", origin).send().await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .expect("expect cors header origin"),
            origin
        );

        // An unlisted origin gets no CORS header, but the request itself still succeeds.
        let res = client
            .get("/health")
            .header("Origin", "https://notallowed.com")
            .send()
            .await;

        assert_eq!(res.status(), StatusCode::OK);
        assert!(!res
            .headers()
            .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN));
    }

    /// With `enable_cors: false`, responses carry no CORS header.
    #[tokio::test]
    pub async fn test_cors_disabled() {
        // CORS is on by default; this test turns it off explicitly.
        let (tx, _rx) = mpsc::channel(100);

        let options = HttpOptions {
            enable_cors: false,
            ..Default::default()
        };

        let app = make_test_app_custom(tx, options);
        let client = TestClient::new(app).await;

        let res = client.get("/health").send().await;

        assert_eq!(res.status(), StatusCode::OK);
        assert!(!res
            .headers()
            .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN));
    }

    /// Checks the default listen address and request timeout of `HttpOptions`.
    #[test]
    fn test_http_options_default() {
        let default = HttpOptions::default();
        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
        assert_eq!(Duration::from_secs(0), default.timeout)
    }

    /// Exercises `/test/timeout` (10 ms default timeout) with and without the timeout header.
    #[tokio::test]
    async fn test_http_server_request_timeout() {
        common_telemetry::init_default_ut_logging();

        let (tx, _rx) = mpsc::channel(100);
        let app = make_test_app(tx);
        let client = TestClient::new(app).await;
        // No header: the layer's 10 ms default applies.
        let res = client.get("/test/timeout").send().await;
        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);

        // A 20 ms header value overrides the default, so the request must take longer than
        // 15 ms before it times out.
        let now = Instant::now();
        let res = client
            .get("/test/timeout")
            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
            .send()
            .await;
        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
        let elapsed = now.elapsed();
        assert!(elapsed > Duration::from_millis(15));

        // A zero timeout ("0s") is expected to disable the timeout: the request must still
        // be pending when the outer 15 ms guard fires.
        tokio::time::timeout(
            Duration::from_millis(15),
            client
                .get("/test/timeout")
                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
                .send(),
        )
        .await
        .unwrap_err();

        // Same expectation for the humantime rendering of a zero `Duration`.
        tokio::time::timeout(
            Duration::from_millis(15),
            client
                .get("/test/timeout")
                .header(
                    GREPTIME_DB_HEADER_TIMEOUT,
                    humantime::format_duration(Duration::default()).to_string(),
                )
                .send(),
        )
        .await
        .unwrap_err();
    }

    /// An empty result set must still carry its schema in the GreptimeDB v1 response.
    #[tokio::test]
    async fn test_schema_for_empty_response() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));

        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];

        let json_resp = GreptimedbV1Response::from_output(outputs).await;
        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
            let json_output = &json_resp.output[0];
            if let GreptimeQueryOutput::Records(r) = json_output {
                assert_eq!(r.num_rows(), 0);
                assert_eq!(r.num_cols(), 2);
                assert_eq!(r.schema.column_schemas[0].name, "numbers");
                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
            } else {
                panic!("invalid output type");
            }
        } else {
            panic!("invalid format")
        }
    }

    /// Converts the same 4-row, 2-column batch with every `ResponseFormat` and checks that
    /// each response preserves shape, column names/types and null handling.
    #[tokio::test]
    async fn test_recordbatches_conversion() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();

        for format in [
            ResponseFormat::GreptimedbV1,
            ResponseFormat::InfluxdbV1,
            ResponseFormat::Csv(true, true),
            ResponseFormat::Table,
            ResponseFormat::Arrow,
            ResponseFormat::Json,
            ResponseFormat::Null,
        ] {
            let recordbatches =
                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
            // Despite its name, `json_resp` holds the response for whichever format is
            // under test, not only JSON.
            let json_resp = match format {
                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
                ResponseFormat::Csv(with_names, with_types) => {
                    CsvResponse::from_output(outputs, with_names, with_types).await
                }
                ResponseFormat::Table => TableResponse::from_output(outputs).await,
                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
                ResponseFormat::Null => NullResponse::from_output(outputs).await,
            };

            match json_resp {
                HttpResponse::GreptimedbV1(resp) => {
                    let json_output = &resp.output[0];
                    if let GreptimeQueryOutput::Records(r) = json_output {
                        assert_eq!(r.num_rows(), 4);
                        assert_eq!(r.num_cols(), 2);
                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
                    } else {
                        panic!("invalid output type");
                    }
                }
                HttpResponse::InfluxdbV1(resp) => {
                    let json_output = &resp.results()[0];
                    assert_eq!(json_output.num_rows(), 4);
                    assert_eq!(json_output.num_cols(), 2);
                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
                    assert_eq!(
                        json_output.series[0].values[0][0],
                        serde_json::Value::from(1)
                    );
                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
                }
                HttpResponse::Csv(resp) => {
                    let output = &resp.output()[0];
                    if let GreptimeQueryOutput::Records(r) = output {
                        assert_eq!(r.num_rows(), 4);
                        assert_eq!(r.num_cols(), 2);
                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
                    } else {
                        panic!("invalid output type");
                    }
                }

                HttpResponse::Table(resp) => {
                    let output = &resp.output()[0];
                    if let GreptimeQueryOutput::Records(r) = output {
                        assert_eq!(r.num_rows(), 4);
                        assert_eq!(r.num_cols(), 2);
                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
                    } else {
                        panic!("invalid output type");
                    }
                }

                // Arrow output is an Arrow IPC file; decode it, then check schema and shape.
                HttpResponse::Arrow(resp) => {
                    let output = resp.data;
                    let mut reader =
                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
                    let schema = reader.schema();
                    assert_eq!(schema.fields[0].name(), "numbers");
                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
                    assert_eq!(schema.fields[1].name(), "strings");
                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);

                    let rb = reader.next().unwrap().expect("read record batch failed");
                    assert_eq!(rb.num_columns(), 2);
                    assert_eq!(rb.num_rows(), 4);
                }

                HttpResponse::Json(resp) => {
                    let output = &resp.output()[0];
                    if let GreptimeQueryOutput::Records(r) = output {
                        assert_eq!(r.num_rows(), 4);
                        assert_eq!(r.num_cols(), 2);
                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
                    } else {
                        panic!("invalid output type");
                    }
                }

                // The null format only reports how many rows were produced.
                HttpResponse::Null(resp) => {
                    assert_eq!(resp.rows(), 4);
                }

                HttpResponse::Error(err) => unreachable!("{err:?}"),
            }
        }
    }

    /// Round-trips `ResponseFormat` through `parse` and `as_str`.
    #[test]
    fn test_response_format_misc() {
        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
        assert_eq!(
            ResponseFormat::parse("csv"),
            Some(ResponseFormat::Csv(false, false))
        );
        assert_eq!(
            ResponseFormat::parse("csvwithnames"),
            Some(ResponseFormat::Csv(true, false))
        );
        assert_eq!(
            ResponseFormat::parse("csvwithnamesandtypes"),
            Some(ResponseFormat::Csv(true, true))
        );
        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
        assert_eq!(
            ResponseFormat::parse("greptimedb_v1"),
            Some(ResponseFormat::GreptimedbV1)
        );
        assert_eq!(
            ResponseFormat::parse("influxdb_v1"),
            Some(ResponseFormat::InfluxdbV1)
        );
        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));

        // invalid formats
        assert_eq!(ResponseFormat::parse("invalid"), None);
        assert_eq!(ResponseFormat::parse(""), None);
        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive

        // as str (every CSV variant reports the same name)
        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
        assert_eq!(ResponseFormat::Table.as_str(), "table");
        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
        assert_eq!(ResponseFormat::Json.as_str(), "json");
        assert_eq!(ResponseFormat::Null.as_str(), "null");
        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
    }
}