Skip to main content

servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::fmt::Display;
18use std::net::SocketAddr;
19use std::sync::{Arc, Mutex as StdMutex};
20use std::time::Duration;
21
22use async_trait::async_trait;
23use auth::UserProviderRef;
24use axum::extract::{DefaultBodyLimit, Request};
25use axum::http::StatusCode as HttpStatusCode;
26use axum::response::{IntoResponse, Response};
27use axum::routing::Route;
28use axum::serve::ListenerExt;
29use axum::{Router, middleware, routing};
30use common_base::Plugins;
31use common_base::readable_size::ReadableSize;
32use common_recordbatch::RecordBatch;
33use common_telemetry::{error, info};
34use common_time::Timestamp;
35use common_time::timestamp::TimeUnit;
36use datatypes::data_type::DataType;
37use datatypes::schema::SchemaRef;
38use event::{LogState, LogValidatorRef};
39use futures::FutureExt;
40use http::{HeaderValue, Method};
41use serde::{Deserialize, Serialize};
42use serde_json::Value;
43use snafu::{ResultExt, ensure};
44use tokio::sync::Mutex;
45use tokio::sync::oneshot::{self, Sender};
46use tonic::codegen::Service;
47use tower::{Layer, ServiceBuilder};
48use tower_http::compression::CompressionLayer;
49use tower_http::cors::{AllowOrigin, Any, CorsLayer};
50use tower_http::decompression::RequestDecompressionLayer;
51use tower_http::trace::TraceLayer;
52
53use self::authorize::AuthState;
54use self::result::table_result::TableResponse;
55use crate::configurator::HttpConfiguratorRef;
56use crate::elasticsearch;
57use crate::error::{
58    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu,
59    OtherSnafu, Result,
60};
61use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
62use crate::http::otlp::OtlpState;
63use crate::http::prom_store::PromStoreState;
64use crate::http::prometheus::{
65    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
66    range_query, series_query,
67};
68use crate::http::result::arrow_result::ArrowResponse;
69use crate::http::result::csv_result::CsvResponse;
70use crate::http::result::error_result::ErrorResponse;
71use crate::http::result::greptime_result_v1::GreptimedbV1Response;
72use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
73use crate::http::result::json_result::JsonResponse;
74use crate::http::result::null_result::NullResponse;
75use crate::interceptor::LogIngestInterceptorRef;
76use crate::metrics::http_metrics_layer;
77use crate::metrics_handler::MetricsHandler;
78use crate::pending_rows_batcher::PendingRowsBatcher;
79use crate::prometheus_handler::PrometheusHandlerRef;
80use crate::query_handler::sql::ServerSqlQueryHandlerRef;
81use crate::query_handler::{
82    DashboardHandlerRef, InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
83    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
84    PromStoreProtocolHandlerRef,
85};
86use crate::request_memory_limiter::ServerMemoryLimiter;
87use crate::server::Server;
88
89pub mod authorize;
90#[cfg(feature = "dashboard")]
91mod dashboard;
92pub mod dyn_log;
93pub mod dyn_trace;
94pub mod event;
95pub mod extractor;
96pub mod handler;
97pub mod header;
98pub mod influxdb;
99pub mod jaeger;
100pub mod logs;
101pub mod loki;
102pub mod mem_prof;
103mod memory_limit;
104pub mod opentsdb;
105pub mod otlp;
106pub mod pprof;
107pub mod prom_store;
108pub mod prometheus;
109pub mod result;
110mod timeout;
111pub mod utils;
112
113use result::HttpOutputWriter;
114pub(crate) use timeout::DynamicTimeoutLayer;
115
116mod client_ip;
117use crate::prom_remote_write::validation::PromValidationMode;
118mod hints;
119mod read_preference;
120#[cfg(any(test, feature = "testing"))]
121pub mod test_helpers;
122
123pub const HTTP_API_VERSION: &str = "v1";
124pub const HTTP_API_PREFIX: &str = "/v1/";
125/// Default http body limit (64M).
126const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);
127
128/// Authorization header
129pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";
130
131// TODO(fys): This is a temporary workaround, it will be improved later
132pub static PUBLIC_API_PREFIX: [&str; 3] =
133    ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
134
135#[derive(Default)]
136pub struct HttpServer {
137    router: StdMutex<Router>,
138    shutdown_tx: Mutex<Option<Sender<()>>>,
139    user_provider: Option<UserProviderRef>,
140    memory_limiter: ServerMemoryLimiter,
141
142    // plugins
143    plugins: Plugins,
144
145    // server configs
146    options: HttpOptions,
147    bind_addr: Option<SocketAddr>,
148}
149
150#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
151#[serde(default)]
152pub struct HttpOptions {
153    pub addr: String,
154
155    #[serde(with = "humantime_serde")]
156    pub timeout: Duration,
157
158    #[serde(skip)]
159    pub disable_dashboard: bool,
160
161    pub body_limit: ReadableSize,
162
163    /// Validation mode while decoding Prometheus remote write requests.
164    pub prom_validation_mode: PromValidationMode,
165
166    pub cors_allowed_origins: Vec<String>,
167
168    pub enable_cors: bool,
169}
170
171impl Default for HttpOptions {
172    fn default() -> Self {
173        Self {
174            addr: "127.0.0.1:4000".to_string(),
175            timeout: Duration::from_secs(0),
176            disable_dashboard: false,
177            body_limit: DEFAULT_BODY_LIMIT,
178            cors_allowed_origins: Vec::new(),
179            enable_cors: true,
180            prom_validation_mode: PromValidationMode::Strict,
181        }
182    }
183}
184
185#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
186pub struct ColumnSchema {
187    name: String,
188    data_type: String,
189}
190
191impl ColumnSchema {
192    pub fn new(name: String, data_type: String) -> ColumnSchema {
193        ColumnSchema { name, data_type }
194    }
195}
196
197#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
198pub struct OutputSchema {
199    column_schemas: Vec<ColumnSchema>,
200}
201
202impl OutputSchema {
203    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
204        OutputSchema {
205            column_schemas: columns,
206        }
207    }
208}
209
210impl From<SchemaRef> for OutputSchema {
211    fn from(schema: SchemaRef) -> OutputSchema {
212        OutputSchema {
213            column_schemas: schema
214                .column_schemas()
215                .iter()
216                .map(|cs| ColumnSchema {
217                    name: cs.name.clone(),
218                    data_type: cs.data_type.name(),
219                })
220                .collect(),
221        }
222    }
223}
224
225#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
226pub struct HttpRecordsOutput {
227    schema: OutputSchema,
228    rows: Vec<Vec<Value>>,
229    // total_rows is equal to rows.len() in most cases,
230    // the Dashboard query result may be truncated, so we need to return the total_rows.
231    #[serde(default)]
232    total_rows: usize,
233
234    // plan level execution metrics
235    #[serde(skip_serializing_if = "HashMap::is_empty")]
236    #[serde(default)]
237    metrics: HashMap<String, Value>,
238}
239
240impl HttpRecordsOutput {
241    pub fn num_rows(&self) -> usize {
242        self.rows.len()
243    }
244
245    pub fn num_cols(&self) -> usize {
246        self.schema.column_schemas.len()
247    }
248
249    pub fn schema(&self) -> &OutputSchema {
250        &self.schema
251    }
252
253    pub fn rows(&self) -> &Vec<Vec<Value>> {
254        &self.rows
255    }
256}
257
258impl HttpRecordsOutput {
259    pub fn try_new(
260        schema: SchemaRef,
261        recordbatches: Vec<RecordBatch>,
262    ) -> std::result::Result<HttpRecordsOutput, Error> {
263        if recordbatches.is_empty() {
264            Ok(HttpRecordsOutput {
265                schema: OutputSchema::from(schema),
266                rows: vec![],
267                total_rows: 0,
268                metrics: Default::default(),
269            })
270        } else {
271            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
272            let mut rows = Vec::with_capacity(num_rows);
273
274            for recordbatch in recordbatches {
275                let mut writer = HttpOutputWriter::new(schema.num_columns(), None);
276                writer.write(recordbatch, &mut rows)?;
277            }
278
279            Ok(HttpRecordsOutput {
280                schema: OutputSchema::from(schema),
281                total_rows: rows.len(),
282                rows,
283                metrics: Default::default(),
284            })
285        }
286    }
287}
288
289#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
290#[serde(rename_all = "lowercase")]
291pub enum GreptimeQueryOutput {
292    AffectedRows(usize),
293    Records(HttpRecordsOutput),
294}
295
/// It allows the results of SQL queries to be presented in different formats.
///
/// The default format is [`ResponseFormat::GreptimedbV1`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    /// CSV output; the flags are `(with_names, with_types)`.
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Parses a format name. Matching is exact and case-sensitive.
    ///
    /// Returns `None` for names that are not recognized.
    pub fn parse(s: &str) -> Option<Self> {
        let format = match s {
            "arrow" => Self::Arrow,
            "csv" => Self::Csv(false, false),
            "csvwithnames" => Self::Csv(true, false),
            "csvwithnamesandtypes" => Self::Csv(true, true),
            "table" => Self::Table,
            "greptimedb_v1" => Self::GreptimedbV1,
            "influxdb_v1" => Self::InfluxdbV1,
            "json" => Self::Json,
            "null" => Self::Null,
            _ => return None,
        };
        Some(format)
    }

    /// Returns the canonical name of the format.
    ///
    /// All CSV flavours map to `"csv"`, so this is not the inverse of [`Self::parse`]
    /// for `csvwithnames` / `csvwithnamesandtypes`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
338
/// Timestamp precision selectable through the `epoch` query parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Epoch {
    /// Nanosecond precision.
    Nanosecond,
    /// Microsecond precision.
    Microsecond,
    /// Millisecond precision.
    Millisecond,
    /// Second precision.
    Second,
}
346
347impl Epoch {
348    pub fn parse(s: &str) -> Option<Epoch> {
349        // Both u and µ indicate microseconds.
350        // epoch = [ns,u,µ,ms,s],
351        // For details, see the Influxdb documents.
352        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
353        match s {
354            "ns" => Some(Epoch::Nanosecond),
355            "u" | "µ" => Some(Epoch::Microsecond),
356            "ms" => Some(Epoch::Millisecond),
357            "s" => Some(Epoch::Second),
358            _ => None, // just returns None for other cases
359        }
360    }
361
362    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
363        match self {
364            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
365            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
366            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
367            Epoch::Second => ts.convert_to(TimeUnit::Second),
368        }
369    }
370}
371
372impl Display for Epoch {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        match self {
375            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
376            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
377            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
378            Epoch::Second => write!(f, "Epoch::Second"),
379        }
380    }
381}
382
383#[derive(Serialize, Deserialize, Debug)]
384pub enum HttpResponse {
385    Arrow(ArrowResponse),
386    Csv(CsvResponse),
387    Table(TableResponse),
388    Error(ErrorResponse),
389    GreptimedbV1(GreptimedbV1Response),
390    InfluxdbV1(InfluxdbV1Response),
391    Json(JsonResponse),
392    Null(NullResponse),
393}
394
395impl HttpResponse {
396    pub fn with_execution_time(self, execution_time: u64) -> Self {
397        match self {
398            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
399            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
400            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
401            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
402            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
403            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
404            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
405            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
406        }
407    }
408
409    pub fn with_limit(self, limit: usize) -> Self {
410        match self {
411            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
412            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
413            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
414            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
415            _ => self,
416        }
417    }
418}
419
420pub fn process_with_limit(
421    mut outputs: Vec<GreptimeQueryOutput>,
422    limit: usize,
423) -> Vec<GreptimeQueryOutput> {
424    outputs
425        .drain(..)
426        .map(|data| match data {
427            GreptimeQueryOutput::Records(mut records) => {
428                if records.rows.len() > limit {
429                    records.rows.truncate(limit);
430                    records.total_rows = limit;
431                }
432                GreptimeQueryOutput::Records(records)
433            }
434            _ => data,
435        })
436        .collect()
437}
438
439impl IntoResponse for HttpResponse {
440    fn into_response(self) -> Response {
441        match self {
442            HttpResponse::Arrow(resp) => resp.into_response(),
443            HttpResponse::Csv(resp) => resp.into_response(),
444            HttpResponse::Table(resp) => resp.into_response(),
445            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
446            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
447            HttpResponse::Json(resp) => resp.into_response(),
448            HttpResponse::Null(resp) => resp.into_response(),
449            HttpResponse::Error(resp) => resp.into_response(),
450        }
451    }
452}
453
454impl From<ArrowResponse> for HttpResponse {
455    fn from(value: ArrowResponse) -> Self {
456        HttpResponse::Arrow(value)
457    }
458}
459
460impl From<CsvResponse> for HttpResponse {
461    fn from(value: CsvResponse) -> Self {
462        HttpResponse::Csv(value)
463    }
464}
465
466impl From<TableResponse> for HttpResponse {
467    fn from(value: TableResponse) -> Self {
468        HttpResponse::Table(value)
469    }
470}
471
472impl From<ErrorResponse> for HttpResponse {
473    fn from(value: ErrorResponse) -> Self {
474        HttpResponse::Error(value)
475    }
476}
477
478impl From<GreptimedbV1Response> for HttpResponse {
479    fn from(value: GreptimedbV1Response) -> Self {
480        HttpResponse::GreptimedbV1(value)
481    }
482}
483
484impl From<InfluxdbV1Response> for HttpResponse {
485    fn from(value: InfluxdbV1Response) -> Self {
486        HttpResponse::InfluxdbV1(value)
487    }
488}
489
490impl From<JsonResponse> for HttpResponse {
491    fn from(value: JsonResponse) -> Self {
492        HttpResponse::Json(value)
493    }
494}
495
496impl From<NullResponse> for HttpResponse {
497    fn from(value: NullResponse) -> Self {
498        HttpResponse::Null(value)
499    }
500}
501
502#[derive(Clone)]
503pub struct ApiState {
504    pub sql_handler: ServerSqlQueryHandlerRef,
505}
506
/// Router state of the config route.
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    /// Configuration options as a string, supplied by the builder.
    pub greptime_config_options: String,
}
511
512#[derive(Clone)]
513pub struct DashboardState {
514    pub handler: DashboardHandlerRef,
515}
516
517pub struct HttpServerBuilder {
518    options: HttpOptions,
519    plugins: Plugins,
520    user_provider: Option<UserProviderRef>,
521    router: Router,
522    memory_limiter: ServerMemoryLimiter,
523}
524
525impl HttpServerBuilder {
526    pub fn new(options: HttpOptions) -> Self {
527        Self {
528            options,
529            plugins: Plugins::default(),
530            user_provider: None,
531            router: Router::new(),
532            memory_limiter: ServerMemoryLimiter::default(),
533        }
534    }
535
536    /// Set a global memory limiter for all server protocols.
537    pub fn with_memory_limiter(mut self, limiter: ServerMemoryLimiter) -> Self {
538        self.memory_limiter = limiter;
539        self
540    }
541
542    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
543        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
544
545        Self {
546            router: self
547                .router
548                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
549            ..self
550        }
551    }
552
553    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
554        let logs_router = HttpServer::route_logs(logs_handler);
555
556        Self {
557            router: self
558                .router
559                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
560            ..self
561        }
562    }
563
564    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
565        Self {
566            router: self.router.nest(
567                &format!("/{HTTP_API_VERSION}/opentsdb"),
568                HttpServer::route_opentsdb(handler),
569            ),
570            ..self
571        }
572    }
573
574    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
575        Self {
576            router: self.router.nest(
577                &format!("/{HTTP_API_VERSION}/influxdb"),
578                HttpServer::route_influxdb(handler),
579            ),
580            ..self
581        }
582    }
583
584    pub fn with_prom_handler(
585        self,
586        handler: PromStoreProtocolHandlerRef,
587        pipeline_handler: Option<PipelineHandlerRef>,
588        prom_store_with_metric_engine: bool,
589        prom_validation_mode: PromValidationMode,
590        pending_rows_batcher: Option<Arc<PendingRowsBatcher>>,
591    ) -> Self {
592        let state = PromStoreState {
593            prom_store_handler: handler,
594            pipeline_handler,
595            prom_store_with_metric_engine,
596            prom_validation_mode,
597            pending_rows_batcher,
598        };
599
600        Self {
601            router: self.router.nest(
602                &format!("/{HTTP_API_VERSION}/prometheus"),
603                HttpServer::route_prom(state),
604            ),
605            ..self
606        }
607    }
608
609    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
610        Self {
611            router: self.router.nest(
612                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
613                HttpServer::route_prometheus(handler),
614            ),
615            ..self
616        }
617    }
618
619    pub fn with_otlp_handler(
620        self,
621        handler: OpenTelemetryProtocolHandlerRef,
622        with_metric_engine: bool,
623    ) -> Self {
624        Self {
625            router: self.router.nest(
626                &format!("/{HTTP_API_VERSION}/otlp"),
627                HttpServer::route_otlp(handler, with_metric_engine),
628            ),
629            ..self
630        }
631    }
632
633    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
634        Self {
635            user_provider: Some(user_provider),
636            ..self
637        }
638    }
639
640    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
641        Self {
642            router: self.router.merge(HttpServer::route_metrics(handler)),
643            ..self
644        }
645    }
646
647    pub fn with_log_ingest_handler(
648        self,
649        handler: PipelineHandlerRef,
650        validator: Option<LogValidatorRef>,
651        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
652    ) -> Self {
653        let log_state = LogState {
654            log_handler: handler,
655            log_validator: validator,
656            ingest_interceptor,
657        };
658
659        let router = self.router.nest(
660            &format!("/{HTTP_API_VERSION}"),
661            HttpServer::route_pipelines(log_state.clone()),
662        );
663        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
664        let router = router.nest(
665            &format!("/{HTTP_API_VERSION}/events"),
666            #[allow(deprecated)]
667            HttpServer::route_log_deprecated(log_state.clone()),
668        );
669
670        let router = router.nest(
671            &format!("/{HTTP_API_VERSION}/loki"),
672            HttpServer::route_loki(log_state.clone()),
673        );
674
675        let router = router.nest(
676            &format!("/{HTTP_API_VERSION}/elasticsearch"),
677            HttpServer::route_elasticsearch(log_state.clone()),
678        );
679
680        let router = router.nest(
681            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
682            Router::new()
683                .route("/", routing::get(elasticsearch::handle_get_version))
684                .with_state(log_state),
685        );
686
687        Self { router, ..self }
688    }
689
690    pub fn with_plugins(self, plugins: Plugins) -> Self {
691        Self { plugins, ..self }
692    }
693
694    pub fn with_greptime_config_options(self, opts: String) -> Self {
695        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
696            greptime_config_options: opts,
697        });
698
699        Self {
700            router: self.router.merge(config_router),
701            ..self
702        }
703    }
704
705    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
706        Self {
707            router: self.router.nest(
708                &format!("/{HTTP_API_VERSION}/jaeger"),
709                HttpServer::route_jaeger(handler),
710            ),
711            ..self
712        }
713    }
714
715    pub fn with_dashboard_handler(self, handler: DashboardHandlerRef) -> Self {
716        Self {
717            router: self.router.nest(
718                &format!("/{HTTP_API_VERSION}/dashboards"),
719                HttpServer::route_dashboard(handler),
720            ),
721            ..self
722        }
723    }
724
725    pub fn with_extra_router(self, router: Router) -> Self {
726        Self {
727            router: self.router.merge(router),
728            ..self
729        }
730    }
731
732    pub fn add_layer<L>(self, layer: L) -> Self
733    where
734        L: Layer<Route> + Clone + Send + Sync + 'static,
735        L::Service: Service<Request> + Clone + Send + Sync + 'static,
736        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
737        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
738        <L::Service as Service<Request>>::Future: Send + 'static,
739    {
740        Self {
741            router: self.router.layer(layer),
742            ..self
743        }
744    }
745
746    pub fn build(self) -> HttpServer {
747        HttpServer {
748            options: self.options,
749            user_provider: self.user_provider,
750            shutdown_tx: Mutex::new(None),
751            plugins: self.plugins,
752            router: StdMutex::new(self.router),
753            bind_addr: None,
754            memory_limiter: self.memory_limiter,
755        }
756    }
757}
758
759impl HttpServer {
760    /// Gets the router and adds necessary root routes (health, status, dashboard).
761    pub fn make_app(&self) -> Router {
762        let mut router = {
763            let router = self.router.lock().unwrap();
764            router.clone()
765        };
766
767        router = router
768            .route("/", routing::get(handler::index))
769            .route(
770                "/health",
771                routing::get(handler::health).post(handler::health),
772            )
773            .route(
774                &format!("/{HTTP_API_VERSION}/health"),
775                routing::get(handler::health).post(handler::health),
776            )
777            .route(
778                "/ready",
779                routing::get(handler::health).post(handler::health),
780            );
781
782        router = router.route("/status", routing::get(handler::status));
783
784        #[cfg(feature = "dashboard")]
785        {
786            if !self.options.disable_dashboard {
787                info!("Enable dashboard service at '/dashboard'");
788                // redirect /dashboard to /dashboard/
789                router = router.route(
790                    "/dashboard",
791                    routing::get(|uri: axum::http::uri::Uri| async move {
792                        let path = uri.path();
793                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
794
795                        let new_uri = format!("{}/{}", path, query);
796                        axum::response::Redirect::permanent(&new_uri)
797                    }),
798                );
799
800                // "/dashboard" and "/dashboard/" are two different paths in Axum.
801                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
802                // So we explicitly route "/dashboard/" here.
803                router = router
804                    .route(
805                        "/dashboard/",
806                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
807                    )
808                    .route(
809                        "/dashboard/{*x}",
810                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
811                    );
812            }
813        }
814
815        // Add a layer to collect HTTP metrics for axum.
816        router = router.route_layer(middleware::from_fn(http_metrics_layer));
817
818        router
819    }
820
821    /// Attaches middlewares and debug routes to the router.
822    /// Callers should call this method after [HttpServer::make_app()].
823    pub fn build(&self, router: Router) -> Result<Router> {
824        let timeout_layer = if self.options.timeout != Duration::default() {
825            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
826        } else {
827            info!("HTTP server timeout is disabled");
828            None
829        };
830        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
831            Some(
832                ServiceBuilder::new()
833                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
834            )
835        } else {
836            info!("HTTP server body limit is disabled");
837            None
838        };
839        let cors_layer = if self.options.enable_cors {
840            Some(
841                CorsLayer::new()
842                    .allow_methods([
843                        Method::GET,
844                        Method::POST,
845                        Method::PUT,
846                        Method::DELETE,
847                        Method::HEAD,
848                    ])
849                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
850                        AllowOrigin::from(Any)
851                    } else {
852                        AllowOrigin::from(
853                            self.options
854                                .cors_allowed_origins
855                                .iter()
856                                .map(|s| {
857                                    HeaderValue::from_str(s.as_str())
858                                        .context(InvalidHeaderValueSnafu)
859                                })
860                                .collect::<Result<Vec<HeaderValue>>>()?,
861                        )
862                    })
863                    .allow_headers(Any),
864            )
865        } else {
866            info!("HTTP server cross-origin is disabled");
867            None
868        };
869
870        Ok(router
871            // middlewares
872            .layer(
873                ServiceBuilder::new()
874                    // disable on failure tracing. because printing out isn't very helpful,
875                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
876                    .layer(TraceLayer::new_for_http().on_failure(()))
877                    .option_layer(cors_layer)
878                    .option_layer(timeout_layer)
879                    .option_layer(body_limit_layer)
880                    // memory limit layer - must be before body is consumed
881                    .layer(middleware::from_fn_with_state(
882                        self.memory_limiter.clone(),
883                        memory_limit::memory_limit_middleware,
884                    ))
885                    // auth layer
886                    .layer(middleware::from_fn_with_state(
887                        AuthState::new(self.user_provider.clone()),
888                        authorize::check_http_auth,
889                    ))
890                    .layer(middleware::from_fn(hints::extract_hints))
891                    .layer(middleware::from_fn(client_ip::log_error_with_client_ip))
892                    .layer(middleware::from_fn(
893                        read_preference::extract_read_preference,
894                    )),
895            )
896            // Handlers for debug, we don't expect a timeout.
897            .nest(
898                "/debug",
899                Router::new()
900                    // handler for changing log level dynamically
901                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
902                    .route("/enable_trace", routing::post(dyn_trace::dyn_trace_handler))
903                    .nest(
904                        "/prof",
905                        Router::new()
906                            .route("/cpu", routing::post(pprof::pprof_handler))
907                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
908                            .route("/mem/symbol", routing::post(mem_prof::symbolicate_handler))
909                            .route(
910                                "/mem/activate",
911                                routing::post(mem_prof::activate_heap_prof_handler),
912                            )
913                            .route(
914                                "/mem/deactivate",
915                                routing::post(mem_prof::deactivate_heap_prof_handler),
916                            )
917                            .route(
918                                "/mem/status",
919                                routing::get(mem_prof::heap_prof_status_handler),
920                            ) // jemalloc gdump flag status and toggle
921                            .route(
922                                "/mem/gdump",
923                                routing::get(mem_prof::gdump_status_handler)
924                                    .post(mem_prof::gdump_toggle_handler),
925                            ),
926                    ),
927            ))
928    }
929
930    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
931        Router::new()
932            .route("/metrics", routing::get(handler::metrics))
933            .with_state(metrics_handler)
934    }
935
936    fn route_loki<S>(log_state: LogState) -> Router<S> {
937        Router::new()
938            .route("/api/v1/push", routing::post(loki::loki_ingest))
939            .layer(
940                ServiceBuilder::new()
941                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
942            )
943            .with_state(log_state)
944    }
945
946    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
947        Router::new()
948            // Return fake responsefor HEAD '/' request.
949            .route(
950                "/",
951                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
952            )
953            // Return fake response for Elasticsearch version request.
954            .route("/", routing::get(elasticsearch::handle_get_version))
955            // Return fake response for Elasticsearch license request.
956            .route("/_license", routing::get(elasticsearch::handle_get_license))
957            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
958            .route(
959                "/{index}/_bulk",
960                routing::post(elasticsearch::handle_bulk_api_with_index),
961            )
962            // Return fake response for Elasticsearch ilm request.
963            .route(
964                "/_ilm/policy/{*path}",
965                routing::any((
966                    HttpStatusCode::OK,
967                    elasticsearch::elasticsearch_headers(),
968                    axum::Json(serde_json::json!({})),
969                )),
970            )
971            // Return fake response for Elasticsearch index template request.
972            .route(
973                "/_index_template/{*path}",
974                routing::any((
975                    HttpStatusCode::OK,
976                    elasticsearch::elasticsearch_headers(),
977                    axum::Json(serde_json::json!({})),
978                )),
979            )
980            // Return fake response for Elasticsearch ingest pipeline request.
981            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
982            .route(
983                "/_ingest/{*path}",
984                routing::any((
985                    HttpStatusCode::OK,
986                    elasticsearch::elasticsearch_headers(),
987                    axum::Json(serde_json::json!({})),
988                )),
989            )
990            // Return fake response for Elasticsearch nodes discovery request.
991            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
992            .route(
993                "/_nodes/{*path}",
994                routing::any((
995                    HttpStatusCode::OK,
996                    elasticsearch::elasticsearch_headers(),
997                    axum::Json(serde_json::json!({})),
998                )),
999            )
1000            // Return fake response for Logstash APIs requests.
1001            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
1002            .route(
1003                "/logstash/{*path}",
1004                routing::any((
1005                    HttpStatusCode::OK,
1006                    elasticsearch::elasticsearch_headers(),
1007                    axum::Json(serde_json::json!({})),
1008                )),
1009            )
1010            .route(
1011                "/_logstash/{*path}",
1012                routing::any((
1013                    HttpStatusCode::OK,
1014                    elasticsearch::elasticsearch_headers(),
1015                    axum::Json(serde_json::json!({})),
1016                )),
1017            )
1018            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1019            .with_state(log_state)
1020    }
1021
1022    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1023    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1024        Router::new()
1025            .route("/logs", routing::post(event::log_ingester))
1026            .route(
1027                "/pipelines/{pipeline_name}",
1028                routing::get(event::query_pipeline),
1029            )
1030            .route(
1031                "/pipelines/{pipeline_name}",
1032                routing::post(event::add_pipeline),
1033            )
1034            .route(
1035                "/pipelines/{pipeline_name}",
1036                routing::delete(event::delete_pipeline),
1037            )
1038            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1039            .layer(
1040                ServiceBuilder::new()
1041                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1042            )
1043            .with_state(log_state)
1044    }
1045
1046    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1047        Router::new()
1048            .route("/ingest", routing::post(event::log_ingester))
1049            .route(
1050                "/pipelines/{pipeline_name}",
1051                routing::get(event::query_pipeline),
1052            )
1053            .route(
1054                "/pipelines/{pipeline_name}/ddl",
1055                routing::get(event::query_pipeline_ddl),
1056            )
1057            .route(
1058                "/pipelines/{pipeline_name}",
1059                routing::post(event::add_pipeline),
1060            )
1061            .route(
1062                "/pipelines/{pipeline_name}",
1063                routing::delete(event::delete_pipeline),
1064            )
1065            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1066            .layer(
1067                ServiceBuilder::new()
1068                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1069            )
1070            .with_state(log_state)
1071    }
1072
1073    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1074        Router::new()
1075            .route("/sql", routing::get(handler::sql).post(handler::sql))
1076            .route(
1077                "/sql/parse",
1078                routing::get(handler::sql_parse).post(handler::sql_parse),
1079            )
1080            .route(
1081                "/sql/format",
1082                routing::get(handler::sql_format).post(handler::sql_format),
1083            )
1084            .route(
1085                "/promql",
1086                routing::get(handler::promql).post(handler::promql),
1087            )
1088            .with_state(api_state)
1089    }
1090
1091    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1092        Router::new()
1093            .route("/logs", routing::get(logs::logs).post(logs::logs))
1094            .with_state(log_handler)
1095    }
1096
1097    /// Route Prometheus [HTTP API].
1098    ///
1099    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1100    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1101        Router::new()
1102            .route(
1103                "/format_query",
1104                routing::post(format_query).get(format_query),
1105            )
1106            .route("/status/buildinfo", routing::get(build_info_query))
1107            .route("/query", routing::post(instant_query).get(instant_query))
1108            .route("/query_range", routing::post(range_query).get(range_query))
1109            .route("/labels", routing::post(labels_query).get(labels_query))
1110            .route("/series", routing::post(series_query).get(series_query))
1111            .route("/parse_query", routing::post(parse_query).get(parse_query))
1112            .route(
1113                "/label/{label_name}/values",
1114                routing::get(label_values_query),
1115            )
1116            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1117            .with_state(prometheus_handler)
1118    }
1119
1120    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1121    /// called `prom_store`.
1122    ///
1123    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1124    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1125    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1126        Router::new()
1127            .route("/read", routing::post(prom_store::remote_read))
1128            .route("/write", routing::post(prom_store::remote_write))
1129            .with_state(state)
1130    }
1131
1132    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1133        Router::new()
1134            .route("/write", routing::post(influxdb_write_v1))
1135            .route("/api/v2/write", routing::post(influxdb_write_v2))
1136            .layer(
1137                ServiceBuilder::new()
1138                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1139            )
1140            .route("/ping", routing::get(influxdb_ping))
1141            .route("/health", routing::get(influxdb_health))
1142            .with_state(influxdb_handler)
1143    }
1144
1145    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1146        Router::new()
1147            .route("/api/put", routing::post(opentsdb::put))
1148            .with_state(opentsdb_handler)
1149    }
1150
1151    fn route_otlp<S>(
1152        otlp_handler: OpenTelemetryProtocolHandlerRef,
1153        with_metric_engine: bool,
1154    ) -> Router<S> {
1155        Router::new()
1156            .route("/v1/metrics", routing::post(otlp::metrics))
1157            .route("/v1/traces", routing::post(otlp::traces))
1158            .route("/v1/logs", routing::post(otlp::logs))
1159            .layer(
1160                ServiceBuilder::new()
1161                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1162            )
1163            .with_state(OtlpState {
1164                with_metric_engine,
1165                handler: otlp_handler,
1166            })
1167    }
1168
1169    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1170        Router::new()
1171            .route("/config", routing::get(handler::config))
1172            .with_state(state)
1173    }
1174
1175    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1176        Router::new()
1177            .route("/api/services", routing::get(jaeger::handle_get_services))
1178            .route(
1179                "/api/services/{service_name}/operations",
1180                routing::get(jaeger::handle_get_operations_by_service),
1181            )
1182            .route(
1183                "/api/operations",
1184                routing::get(jaeger::handle_get_operations),
1185            )
1186            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1187            .route(
1188                "/api/traces/{trace_id}",
1189                routing::get(jaeger::handle_get_trace),
1190            )
1191            .with_state(handler)
1192    }
1193
1194    #[cfg(feature = "dashboard")]
1195    fn route_dashboard<S>(handler: DashboardHandlerRef) -> Router<S> {
1196        use crate::http::dashboard::{add_dashboard, delete_dashboard, list_dashboards};
1197
1198        Router::new()
1199            .route("/", routing::get(list_dashboards))
1200            .route("/{dashboard_name}", routing::post(add_dashboard))
1201            .route("/{dashboard_name}", routing::delete(delete_dashboard))
1202            .layer(
1203                ServiceBuilder::new()
1204                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1205            )
1206            .with_state(DashboardState { handler })
1207    }
1208
1209    #[cfg(not(feature = "dashboard"))]
1210    fn route_dashboard<S>(handler: DashboardHandlerRef) -> Router<S> {
1211        Router::new().with_state(DashboardState { handler })
1212    }
1213}
1214
1215pub const HTTP_SERVER: &str = "HTTP_SERVER";
1216
1217#[async_trait]
1218impl Server for HttpServer {
1219    async fn shutdown(&self) -> Result<()> {
1220        let mut shutdown_tx = self.shutdown_tx.lock().await;
1221        if let Some(tx) = shutdown_tx.take()
1222            && tx.send(()).is_err()
1223        {
1224            info!("Receiver dropped, the HTTP server has already exited");
1225        }
1226        info!("Shutdown HTTP server");
1227
1228        Ok(())
1229    }
1230
1231    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1232        let (tx, rx) = oneshot::channel();
1233        let serve = {
1234            let mut shutdown_tx = self.shutdown_tx.lock().await;
1235            ensure!(
1236                shutdown_tx.is_none(),
1237                AlreadyStartedSnafu { server: "HTTP" }
1238            );
1239
1240            let mut app = self.make_app();
1241            if let Some(configurator) = self.plugins.get::<HttpConfiguratorRef<()>>() {
1242                app = configurator
1243                    .configure_http(app, ())
1244                    .await
1245                    .context(OtherSnafu)?;
1246            }
1247            let app = self.build(app)?;
1248            let listener = tokio::net::TcpListener::bind(listening)
1249                .await
1250                .context(AddressBindSnafu { addr: listening })?
1251                .tap_io(|tcp_stream| {
1252                    if let Err(e) = tcp_stream.set_nodelay(true) {
1253                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1254                    }
1255                });
1256            let serve = axum::serve(
1257                listener,
1258                app.into_make_service_with_connect_info::<SocketAddr>(),
1259            );
1260
1261            // FIXME(yingwen): Support keepalive.
1262            // See:
1263            // - https://github.com/tokio-rs/axum/discussions/2939
1264            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1265            // let server = axum::Server::try_bind(&listening)
1266            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1267            //     .tcp_nodelay(true)
1268            //     // Enable TCP keepalive to close the dangling established connections.
1269            //     // It's configured to let the keepalive probes first send after the connection sits
1270            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1271            //     // So the connection will be closed after roughly 1 hour.
1272            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1273            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1274            //     .tcp_keepalive_retries(Some(6))
1275            //     .serve(app.into_make_service());
1276
1277            *shutdown_tx = Some(tx);
1278
1279            serve
1280        };
1281        let listening = serve.local_addr().context(InternalIoSnafu)?;
1282        info!("HTTP server is bound to {}", listening);
1283
1284        common_runtime::spawn_global(async move {
1285            if let Err(e) = serve
1286                .with_graceful_shutdown(rx.map(drop))
1287                .await
1288                .context(InternalIoSnafu)
1289            {
1290                error!(e; "Failed to shutdown http server");
1291            }
1292        });
1293
1294        self.bind_addr = Some(listening);
1295        Ok(())
1296    }
1297
1298    fn name(&self) -> &str {
1299        HTTP_SERVER
1300    }
1301
1302    fn bind_addr(&self) -> Option<SocketAddr> {
1303        self.bind_addr
1304    }
1305
1306    fn as_any(&self) -> &dyn std::any::Any {
1307        self
1308    }
1309}
1310
1311#[cfg(test)]
1312mod test {
1313    use std::future::pending;
1314    use std::io::Cursor;
1315    use std::sync::Arc;
1316
1317    use arrow_ipc::reader::StreamReader;
1318    use arrow_schema::DataType;
1319    use axum::handler::Handler;
1320    use axum::http::StatusCode;
1321    use axum::routing::get;
1322    use common_query::Output;
1323    use common_recordbatch::RecordBatches;
1324    use datafusion_expr::LogicalPlan;
1325    use datatypes::prelude::*;
1326    use datatypes::schema::{ColumnSchema, Schema};
1327    use datatypes::vectors::{StringVector, UInt32Vector};
1328    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1329    use query::parser::PromQuery;
1330    use query::query_engine::DescribeResult;
1331    use session::context::QueryContextRef;
1332    use tokio::sync::mpsc;
1333    use tokio::time::Instant;
1334
1335    use super::*;
1336    use crate::http::test_helpers::TestClient;
1337    use crate::prom_remote_write::validation::validate_label_name;
1338    use crate::query_handler::sql::SqlQueryHandler;
1339
1340    struct DummyInstance {
1341        _tx: mpsc::Sender<(String, Vec<u8>)>,
1342    }
1343
1344    #[async_trait]
1345    impl SqlQueryHandler for DummyInstance {
1346        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1347            unimplemented!()
1348        }
1349
1350        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
1351            unimplemented!()
1352        }
1353
1354        async fn do_exec_plan(
1355            &self,
1356            _plan: LogicalPlan,
1357            _query: String,
1358            _query_ctx: QueryContextRef,
1359        ) -> Result<Output> {
1360            unimplemented!()
1361        }
1362
1363        async fn do_describe(
1364            &self,
1365            _stmt: sql::statements::statement::Statement,
1366            _query_ctx: QueryContextRef,
1367        ) -> Result<Option<DescribeResult>> {
1368            unimplemented!()
1369        }
1370
1371        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1372            Ok(true)
1373        }
1374    }
1375
1376    fn timeout() -> DynamicTimeoutLayer {
1377        DynamicTimeoutLayer::new(Duration::from_millis(10))
1378    }
1379
1380    async fn forever() {
1381        pending().await
1382    }
1383
1384    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1385        make_test_app_custom(tx, HttpOptions::default())
1386    }
1387
1388    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1389        let instance = Arc::new(DummyInstance { _tx: tx });
1390        let server = HttpServerBuilder::new(options)
1391            .with_sql_handler(instance.clone())
1392            .build();
1393        server.build(server.make_app()).unwrap().route(
1394            "/test/timeout",
1395            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1396        )
1397    }
1398
1399    #[tokio::test]
1400    pub async fn test_cors() {
1401        // cors is on by default
1402        let (tx, _rx) = mpsc::channel(100);
1403        let app = make_test_app(tx);
1404        let client = TestClient::new(app).await;
1405
1406        let res = client.get("/health").send().await;
1407
1408        assert_eq!(res.status(), StatusCode::OK);
1409        assert_eq!(
1410            res.headers()
1411                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1412                .expect("expect cors header origin"),
1413            "*"
1414        );
1415
1416        let res = client.get("/v1/health").send().await;
1417
1418        assert_eq!(res.status(), StatusCode::OK);
1419        assert_eq!(
1420            res.headers()
1421                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1422                .expect("expect cors header origin"),
1423            "*"
1424        );
1425
1426        let res = client
1427            .options("/health")
1428            .header("Access-Control-Request-Headers", "x-greptime-auth")
1429            .header("Access-Control-Request-Method", "DELETE")
1430            .header("Origin", "https://example.com")
1431            .send()
1432            .await;
1433        assert_eq!(res.status(), StatusCode::OK);
1434        assert_eq!(
1435            res.headers()
1436                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1437                .expect("expect cors header origin"),
1438            "*"
1439        );
1440        assert_eq!(
1441            res.headers()
1442                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1443                .expect("expect cors header headers"),
1444            "*"
1445        );
1446        assert_eq!(
1447            res.headers()
1448                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1449                .expect("expect cors header methods"),
1450            "GET,POST,PUT,DELETE,HEAD"
1451        );
1452    }
1453
1454    #[tokio::test]
1455    pub async fn test_cors_custom_origins() {
1456        // cors is on by default
1457        let (tx, _rx) = mpsc::channel(100);
1458        let origin = "https://example.com";
1459
1460        let options = HttpOptions {
1461            cors_allowed_origins: vec![origin.to_string()],
1462            ..Default::default()
1463        };
1464
1465        let app = make_test_app_custom(tx, options);
1466        let client = TestClient::new(app).await;
1467
1468        let res = client.get("/health").header("Origin", origin).send().await;
1469
1470        assert_eq!(res.status(), StatusCode::OK);
1471        assert_eq!(
1472            res.headers()
1473                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1474                .expect("expect cors header origin"),
1475            origin
1476        );
1477
1478        let res = client
1479            .get("/health")
1480            .header("Origin", "https://notallowed.com")
1481            .send()
1482            .await;
1483
1484        assert_eq!(res.status(), StatusCode::OK);
1485        assert!(
1486            !res.headers()
1487                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1488        );
1489    }
1490
1491    #[tokio::test]
1492    pub async fn test_cors_disabled() {
1493        // cors is on by default
1494        let (tx, _rx) = mpsc::channel(100);
1495
1496        let options = HttpOptions {
1497            enable_cors: false,
1498            ..Default::default()
1499        };
1500
1501        let app = make_test_app_custom(tx, options);
1502        let client = TestClient::new(app).await;
1503
1504        let res = client.get("/health").send().await;
1505
1506        assert_eq!(res.status(), StatusCode::OK);
1507        assert!(
1508            !res.headers()
1509                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1510        );
1511    }
1512
1513    #[test]
1514    fn test_http_options_default() {
1515        let default = HttpOptions::default();
1516        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1517        assert_eq!(Duration::from_secs(0), default.timeout)
1518    }
1519
1520    #[tokio::test]
1521    async fn test_http_server_request_timeout() {
1522        common_telemetry::init_default_ut_logging();
1523
1524        let (tx, _rx) = mpsc::channel(100);
1525        let app = make_test_app(tx);
1526        let client = TestClient::new(app).await;
1527        let res = client.get("/test/timeout").send().await;
1528        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1529
1530        let now = Instant::now();
1531        let res = client
1532            .get("/test/timeout")
1533            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1534            .send()
1535            .await;
1536        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1537        let elapsed = now.elapsed();
1538        assert!(elapsed > Duration::from_millis(15));
1539
1540        tokio::time::timeout(
1541            Duration::from_millis(15),
1542            client
1543                .get("/test/timeout")
1544                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1545                .send(),
1546        )
1547        .await
1548        .unwrap_err();
1549
1550        tokio::time::timeout(
1551            Duration::from_millis(15),
1552            client
1553                .get("/test/timeout")
1554                .header(
1555                    GREPTIME_DB_HEADER_TIMEOUT,
1556                    humantime::format_duration(Duration::default()).to_string(),
1557                )
1558                .send(),
1559        )
1560        .await
1561        .unwrap_err();
1562    }
1563
1564    #[tokio::test]
1565    async fn test_schema_for_empty_response() {
1566        let column_schemas = vec![
1567            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1568            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1569        ];
1570        let schema = Arc::new(Schema::new(column_schemas));
1571
1572        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1573        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1574
1575        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1576        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1577            let json_output = &json_resp.output[0];
1578            if let GreptimeQueryOutput::Records(r) = json_output {
1579                assert_eq!(r.num_rows(), 0);
1580                assert_eq!(r.num_cols(), 2);
1581                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1582                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1583            } else {
1584                panic!("invalid output type");
1585            }
1586        } else {
1587            panic!("invalid format")
1588        }
1589    }
1590
1591    #[tokio::test]
1592    async fn test_recordbatches_conversion() {
1593        let column_schemas = vec![
1594            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1595            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1596        ];
1597        let schema = Arc::new(Schema::new(column_schemas));
1598        let columns: Vec<VectorRef> = vec![
1599            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1600            Arc::new(StringVector::from(vec![
1601                None,
1602                Some("hello"),
1603                Some("greptime"),
1604                None,
1605            ])),
1606        ];
1607        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1608
1609        for format in [
1610            ResponseFormat::GreptimedbV1,
1611            ResponseFormat::InfluxdbV1,
1612            ResponseFormat::Csv(true, true),
1613            ResponseFormat::Table,
1614            ResponseFormat::Arrow,
1615            ResponseFormat::Json,
1616            ResponseFormat::Null,
1617        ] {
1618            let recordbatches =
1619                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1620            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1621            let json_resp = match format {
1622                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1623                ResponseFormat::Csv(with_names, with_types) => {
1624                    CsvResponse::from_output(outputs, with_names, with_types).await
1625                }
1626                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1627                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1628                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1629                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1630                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1631            };
1632
1633            match json_resp {
1634                HttpResponse::GreptimedbV1(resp) => {
1635                    let json_output = &resp.output[0];
1636                    if let GreptimeQueryOutput::Records(r) = json_output {
1637                        assert_eq!(r.num_rows(), 4);
1638                        assert_eq!(r.num_cols(), 2);
1639                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1640                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1641                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1642                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1643                    } else {
1644                        panic!("invalid output type");
1645                    }
1646                }
1647                HttpResponse::InfluxdbV1(resp) => {
1648                    let json_output = &resp.results()[0];
1649                    assert_eq!(json_output.num_rows(), 4);
1650                    assert_eq!(json_output.num_cols(), 2);
1651                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1652                    assert_eq!(
1653                        json_output.series[0].values[0][0],
1654                        serde_json::Value::from(1)
1655                    );
1656                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1657                }
1658                HttpResponse::Csv(resp) => {
1659                    let output = &resp.output()[0];
1660                    if let GreptimeQueryOutput::Records(r) = output {
1661                        assert_eq!(r.num_rows(), 4);
1662                        assert_eq!(r.num_cols(), 2);
1663                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1664                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1665                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1666                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1667                    } else {
1668                        panic!("invalid output type");
1669                    }
1670                }
1671
1672                HttpResponse::Table(resp) => {
1673                    let output = &resp.output()[0];
1674                    if let GreptimeQueryOutput::Records(r) = output {
1675                        assert_eq!(r.num_rows(), 4);
1676                        assert_eq!(r.num_cols(), 2);
1677                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1678                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1679                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1680                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1681                    } else {
1682                        panic!("invalid output type");
1683                    }
1684                }
1685
1686                HttpResponse::Arrow(resp) => {
1687                    let output = resp.data;
1688                    let mut reader = StreamReader::try_new(Cursor::new(output), None)
1689                        .expect("Arrow reader error");
1690                    let schema = reader.schema();
1691                    assert_eq!(schema.fields[0].name(), "numbers");
1692                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1693                    assert_eq!(schema.fields[1].name(), "strings");
1694                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1695
1696                    let rb = reader.next().unwrap().expect("read record batch failed");
1697                    assert_eq!(rb.num_columns(), 2);
1698                    assert_eq!(rb.num_rows(), 4);
1699                }
1700
1701                HttpResponse::Json(resp) => {
1702                    let output = &resp.output()[0];
1703                    if let GreptimeQueryOutput::Records(r) = output {
1704                        assert_eq!(r.num_rows(), 4);
1705                        assert_eq!(r.num_cols(), 2);
1706                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1707                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1708                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1709                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1710                    } else {
1711                        panic!("invalid output type");
1712                    }
1713                }
1714
1715                HttpResponse::Null(resp) => {
1716                    assert_eq!(resp.rows(), 4);
1717                }
1718
1719                HttpResponse::Error(err) => unreachable!("{err:?}"),
1720            }
1721        }
1722    }
1723
1724    #[test]
1725    fn test_response_format_misc() {
1726        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
1727        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
1728        assert_eq!(
1729            ResponseFormat::parse("csv"),
1730            Some(ResponseFormat::Csv(false, false))
1731        );
1732        assert_eq!(
1733            ResponseFormat::parse("csvwithnames"),
1734            Some(ResponseFormat::Csv(true, false))
1735        );
1736        assert_eq!(
1737            ResponseFormat::parse("csvwithnamesandtypes"),
1738            Some(ResponseFormat::Csv(true, true))
1739        );
1740        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
1741        assert_eq!(
1742            ResponseFormat::parse("greptimedb_v1"),
1743            Some(ResponseFormat::GreptimedbV1)
1744        );
1745        assert_eq!(
1746            ResponseFormat::parse("influxdb_v1"),
1747            Some(ResponseFormat::InfluxdbV1)
1748        );
1749        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
1750        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));
1751
1752        // invalid formats
1753        assert_eq!(ResponseFormat::parse("invalid"), None);
1754        assert_eq!(ResponseFormat::parse(""), None);
1755        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive
1756
1757        // as str
1758        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
1759        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
1760        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
1761        assert_eq!(ResponseFormat::Table.as_str(), "table");
1762        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
1763        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
1764        assert_eq!(ResponseFormat::Json.as_str(), "json");
1765        assert_eq!(ResponseFormat::Null.as_str(), "null");
1766        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
1767    }
1768
1769    #[test]
1770    fn test_decode_label_name_strict() {
1771        let strict = PromValidationMode::Strict;
1772
1773        // Valid Prometheus label names
1774        assert!(strict.decode_label_name(b"__name__").is_ok());
1775        assert!(strict.decode_label_name(b"job").is_ok());
1776        assert!(strict.decode_label_name(b"instance").is_ok());
1777        assert!(strict.decode_label_name(b"_private").is_ok());
1778        assert!(strict.decode_label_name(b"label_with_underscores").is_ok());
1779        assert!(strict.decode_label_name(b"abc123").is_ok());
1780        assert!(strict.decode_label_name(b"A").is_ok());
1781        assert!(strict.decode_label_name(b"_").is_ok());
1782
1783        // Invalid: starts with digit
1784        assert!(strict.decode_label_name(b"0abc").is_err());
1785        assert!(strict.decode_label_name(b"123").is_err());
1786
1787        // Invalid: contains special characters
1788        assert!(strict.decode_label_name(b"label-name").is_err());
1789        assert!(strict.decode_label_name(b"label.name").is_err());
1790        assert!(strict.decode_label_name(b"label name").is_err());
1791        assert!(strict.decode_label_name(b"label/name").is_err());
1792
1793        // Invalid: empty
1794        assert!(strict.decode_label_name(b"").is_err());
1795
1796        // Invalid: non-ASCII UTF-8
1797        assert!(strict.decode_label_name("ラベル".as_bytes()).is_err());
1798
1799        // Invalid UTF-8 bytes should fail
1800        assert!(strict.decode_label_name(&[0xff, 0xfe]).is_err());
1801    }
1802
1803    #[test]
1804    fn test_decode_label_name_lossy() {
1805        let lossy = PromValidationMode::Lossy;
1806
1807        // Label name validation is always enforced.
1808        assert!(lossy.decode_label_name(b"__name__").is_ok());
1809        assert!(lossy.decode_label_name(b"label-name").is_err());
1810        assert!(lossy.decode_label_name(b"0abc").is_err());
1811
1812        // Invalid UTF-8 bytes fail the label-name byte check.
1813        assert!(lossy.decode_label_name(&[0xff, 0xfe]).is_err());
1814    }
1815
1816    #[test]
1817    fn test_decode_label_name_unchecked() {
1818        let unchecked = PromValidationMode::Unchecked;
1819
1820        // Label name validation is always enforced.
1821        assert!(unchecked.decode_label_name(b"__name__").is_ok());
1822        assert!(unchecked.decode_label_name(b"label-name").is_err());
1823        assert!(unchecked.decode_label_name(b"0abc").is_err());
1824    }
1825
1826    #[test]
1827    fn test_is_valid_prom_label_name_bytes() {
1828        assert!(validate_label_name(b"__name__"));
1829        assert!(validate_label_name(b"job"));
1830        assert!(validate_label_name(b"_"));
1831        assert!(validate_label_name(b"A"));
1832        assert!(validate_label_name(b"abc123"));
1833        assert!(validate_label_name(b"_leading_underscore"));
1834
1835        assert!(!validate_label_name(b""));
1836        assert!(!validate_label_name(b"0starts_with_digit"));
1837        assert!(!validate_label_name(b"has-dash"));
1838        assert!(!validate_label_name(b"has.dot"));
1839        assert!(!validate_label_name(b"has space"));
1840        assert!(!validate_label_name(&[0xff, 0xfe]));
1841    }
1842}