Skip to main content

servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::fmt::Display;
18use std::net::SocketAddr;
19use std::sync::{Arc, Mutex as StdMutex};
20use std::time::Duration;
21
22use async_trait::async_trait;
23use auth::UserProviderRef;
24use axum::extract::{DefaultBodyLimit, Request};
25use axum::http::StatusCode as HttpStatusCode;
26use axum::response::{IntoResponse, Response};
27use axum::routing::Route;
28use axum::serve::ListenerExt;
29use axum::{Router, middleware, routing};
30use common_base::Plugins;
31use common_base::readable_size::ReadableSize;
32use common_recordbatch::RecordBatch;
33use common_telemetry::{error, info};
34use common_time::Timestamp;
35use common_time::timestamp::TimeUnit;
36use datatypes::data_type::DataType;
37use datatypes::schema::SchemaRef;
38use event::{LogState, LogValidatorRef};
39use futures::FutureExt;
40use http::{HeaderValue, Method};
41use serde::{Deserialize, Serialize};
42use serde_json::Value;
43use snafu::{ResultExt, ensure};
44use tokio::sync::Mutex;
45use tokio::sync::oneshot::{self, Sender};
46use tonic::codegen::Service;
47use tower::{Layer, ServiceBuilder};
48use tower_http::compression::CompressionLayer;
49use tower_http::cors::{AllowOrigin, Any, CorsLayer};
50use tower_http::decompression::RequestDecompressionLayer;
51use tower_http::trace::TraceLayer;
52
53use self::authorize::AuthState;
54use self::result::table_result::TableResponse;
55use crate::configurator::HttpConfiguratorRef;
56use crate::elasticsearch;
57use crate::error::{
58    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu,
59    OtherSnafu, Result,
60};
61use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
62use crate::http::otlp::OtlpState;
63use crate::http::prom_store::PromStoreState;
64use crate::http::prometheus::{
65    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
66    range_query, series_query,
67};
68use crate::http::result::arrow_result::ArrowResponse;
69use crate::http::result::csv_result::CsvResponse;
70use crate::http::result::error_result::ErrorResponse;
71use crate::http::result::greptime_result_v1::GreptimedbV1Response;
72use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
73use crate::http::result::json_result::JsonResponse;
74use crate::http::result::null_result::NullResponse;
75use crate::interceptor::LogIngestInterceptorRef;
76use crate::metrics::http_metrics_layer;
77use crate::metrics_handler::MetricsHandler;
78use crate::pending_rows_batcher::PendingRowsBatcher;
79use crate::prometheus_handler::PrometheusHandlerRef;
80use crate::query_handler::sql::ServerSqlQueryHandlerRef;
81use crate::query_handler::{
82    DashboardHandlerRef, InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
83    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
84    PromStoreProtocolHandlerRef,
85};
86use crate::request_memory_limiter::ServerMemoryLimiter;
87use crate::server::Server;
88
89pub mod authorize;
90#[cfg(feature = "dashboard")]
91mod dashboard;
92pub mod dyn_log;
93pub mod dyn_trace;
94pub mod event;
95pub mod extractor;
96pub mod handler;
97pub mod header;
98pub mod influxdb;
99pub mod jaeger;
100pub mod logs;
101pub mod loki;
102pub mod mem_prof;
103mod memory_limit;
104pub mod opentsdb;
105pub mod otlp;
106pub mod pprof;
107pub mod prom_store;
108pub mod prometheus;
109pub mod result;
110mod timeout;
111pub mod utils;
112
113use result::HttpOutputWriter;
114pub(crate) use timeout::DynamicTimeoutLayer;
115
116mod client_ip;
117use crate::prom_remote_write::validation::PromValidationMode;
118mod hints;
119mod read_preference;
120#[cfg(any(test, feature = "testing"))]
121pub mod test_helpers;
122
123pub const HTTP_API_VERSION: &str = "v1";
124pub const HTTP_API_PREFIX: &str = "/v1/";
125/// Default http body limit (64M).
126const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);
127
128/// Authorization header
129pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";
130
131// TODO(fys): This is a temporary workaround, it will be improved later
132pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
133
134#[derive(Default)]
135pub struct HttpServer {
136    router: StdMutex<Router>,
137    shutdown_tx: Mutex<Option<Sender<()>>>,
138    user_provider: Option<UserProviderRef>,
139    memory_limiter: ServerMemoryLimiter,
140
141    // plugins
142    plugins: Plugins,
143
144    // server configs
145    options: HttpOptions,
146    bind_addr: Option<SocketAddr>,
147}
148
149#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
150#[serde(default)]
151pub struct HttpOptions {
152    pub addr: String,
153
154    #[serde(with = "humantime_serde")]
155    pub timeout: Duration,
156
157    #[serde(skip)]
158    pub disable_dashboard: bool,
159
160    pub body_limit: ReadableSize,
161
162    /// Validation mode while decoding Prometheus remote write requests.
163    pub prom_validation_mode: PromValidationMode,
164
165    pub cors_allowed_origins: Vec<String>,
166
167    pub enable_cors: bool,
168}
169
170impl Default for HttpOptions {
171    fn default() -> Self {
172        Self {
173            addr: "127.0.0.1:4000".to_string(),
174            timeout: Duration::from_secs(0),
175            disable_dashboard: false,
176            body_limit: DEFAULT_BODY_LIMIT,
177            cors_allowed_origins: Vec::new(),
178            enable_cors: true,
179            prom_validation_mode: PromValidationMode::Strict,
180        }
181    }
182}
183
184#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
185pub struct ColumnSchema {
186    name: String,
187    data_type: String,
188}
189
190impl ColumnSchema {
191    pub fn new(name: String, data_type: String) -> ColumnSchema {
192        ColumnSchema { name, data_type }
193    }
194}
195
196#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
197pub struct OutputSchema {
198    column_schemas: Vec<ColumnSchema>,
199}
200
201impl OutputSchema {
202    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
203        OutputSchema {
204            column_schemas: columns,
205        }
206    }
207}
208
209impl From<SchemaRef> for OutputSchema {
210    fn from(schema: SchemaRef) -> OutputSchema {
211        OutputSchema {
212            column_schemas: schema
213                .column_schemas()
214                .iter()
215                .map(|cs| ColumnSchema {
216                    name: cs.name.clone(),
217                    data_type: cs.data_type.name(),
218                })
219                .collect(),
220        }
221    }
222}
223
224#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
225pub struct HttpRecordsOutput {
226    schema: OutputSchema,
227    rows: Vec<Vec<Value>>,
228    // total_rows is equal to rows.len() in most cases,
229    // the Dashboard query result may be truncated, so we need to return the total_rows.
230    #[serde(default)]
231    total_rows: usize,
232
233    // plan level execution metrics
234    #[serde(skip_serializing_if = "HashMap::is_empty")]
235    #[serde(default)]
236    metrics: HashMap<String, Value>,
237}
238
239impl HttpRecordsOutput {
240    pub fn num_rows(&self) -> usize {
241        self.rows.len()
242    }
243
244    pub fn num_cols(&self) -> usize {
245        self.schema.column_schemas.len()
246    }
247
248    pub fn schema(&self) -> &OutputSchema {
249        &self.schema
250    }
251
252    pub fn rows(&self) -> &Vec<Vec<Value>> {
253        &self.rows
254    }
255}
256
257impl HttpRecordsOutput {
258    pub fn try_new(
259        schema: SchemaRef,
260        recordbatches: Vec<RecordBatch>,
261    ) -> std::result::Result<HttpRecordsOutput, Error> {
262        if recordbatches.is_empty() {
263            Ok(HttpRecordsOutput {
264                schema: OutputSchema::from(schema),
265                rows: vec![],
266                total_rows: 0,
267                metrics: Default::default(),
268            })
269        } else {
270            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
271            let mut rows = Vec::with_capacity(num_rows);
272
273            for recordbatch in recordbatches {
274                let mut writer = HttpOutputWriter::new(schema.num_columns(), None);
275                writer.write(recordbatch, &mut rows)?;
276            }
277
278            Ok(HttpRecordsOutput {
279                schema: OutputSchema::from(schema),
280                total_rows: rows.len(),
281                rows,
282                metrics: Default::default(),
283            })
284        }
285    }
286}
287
288#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
289#[serde(rename_all = "lowercase")]
290pub enum GreptimeQueryOutput {
291    AffectedRows(usize),
292    Records(HttpRecordsOutput),
293}
294
/// Output format in which SQL query results are presented over HTTP.
///
/// The CSV variant carries the `(with_names, with_types)` flags selected through
/// the `csv`, `csvwithnames` and `csvwithnamesandtypes` format names.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    // (with_names, with_types)
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Parses a format name (case-sensitive); unknown names yield `None`.
    pub fn parse(s: &str) -> Option<Self> {
        let format = match s {
            "arrow" => Self::Arrow,
            "csv" => Self::Csv(false, false),
            "csvwithnames" => Self::Csv(true, false),
            "csvwithnamesandtypes" => Self::Csv(true, true),
            "table" => Self::Table,
            "greptimedb_v1" => Self::GreptimedbV1,
            "influxdb_v1" => Self::InfluxdbV1,
            "json" => Self::Json,
            "null" => Self::Null,
            _ => return None,
        };
        Some(format)
    }

    /// Canonical name of the format. Every CSV flavour is reported as `"csv"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
337
338#[derive(Debug, Clone, Copy, PartialEq, Eq)]
339pub enum Epoch {
340    Nanosecond,
341    Microsecond,
342    Millisecond,
343    Second,
344}
345
346impl Epoch {
347    pub fn parse(s: &str) -> Option<Epoch> {
348        // Both u and µ indicate microseconds.
349        // epoch = [ns,u,µ,ms,s],
350        // For details, see the Influxdb documents.
351        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
352        match s {
353            "ns" => Some(Epoch::Nanosecond),
354            "u" | "µ" => Some(Epoch::Microsecond),
355            "ms" => Some(Epoch::Millisecond),
356            "s" => Some(Epoch::Second),
357            _ => None, // just returns None for other cases
358        }
359    }
360
361    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
362        match self {
363            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
364            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
365            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
366            Epoch::Second => ts.convert_to(TimeUnit::Second),
367        }
368    }
369}
370
371impl Display for Epoch {
372    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373        match self {
374            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
375            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
376            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
377            Epoch::Second => write!(f, "Epoch::Second"),
378        }
379    }
380}
381
382#[derive(Serialize, Deserialize, Debug)]
383pub enum HttpResponse {
384    Arrow(ArrowResponse),
385    Csv(CsvResponse),
386    Table(TableResponse),
387    Error(ErrorResponse),
388    GreptimedbV1(GreptimedbV1Response),
389    InfluxdbV1(InfluxdbV1Response),
390    Json(JsonResponse),
391    Null(NullResponse),
392}
393
394impl HttpResponse {
395    pub fn with_execution_time(self, execution_time: u64) -> Self {
396        match self {
397            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
398            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
399            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
400            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
401            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
402            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
403            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
404            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
405        }
406    }
407
408    pub fn with_limit(self, limit: usize) -> Self {
409        match self {
410            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
411            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
412            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
413            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
414            _ => self,
415        }
416    }
417}
418
419pub fn process_with_limit(
420    mut outputs: Vec<GreptimeQueryOutput>,
421    limit: usize,
422) -> Vec<GreptimeQueryOutput> {
423    outputs
424        .drain(..)
425        .map(|data| match data {
426            GreptimeQueryOutput::Records(mut records) => {
427                if records.rows.len() > limit {
428                    records.rows.truncate(limit);
429                    records.total_rows = limit;
430                }
431                GreptimeQueryOutput::Records(records)
432            }
433            _ => data,
434        })
435        .collect()
436}
437
438impl IntoResponse for HttpResponse {
439    fn into_response(self) -> Response {
440        match self {
441            HttpResponse::Arrow(resp) => resp.into_response(),
442            HttpResponse::Csv(resp) => resp.into_response(),
443            HttpResponse::Table(resp) => resp.into_response(),
444            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
445            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
446            HttpResponse::Json(resp) => resp.into_response(),
447            HttpResponse::Null(resp) => resp.into_response(),
448            HttpResponse::Error(resp) => resp.into_response(),
449        }
450    }
451}
452
453impl From<ArrowResponse> for HttpResponse {
454    fn from(value: ArrowResponse) -> Self {
455        HttpResponse::Arrow(value)
456    }
457}
458
459impl From<CsvResponse> for HttpResponse {
460    fn from(value: CsvResponse) -> Self {
461        HttpResponse::Csv(value)
462    }
463}
464
465impl From<TableResponse> for HttpResponse {
466    fn from(value: TableResponse) -> Self {
467        HttpResponse::Table(value)
468    }
469}
470
471impl From<ErrorResponse> for HttpResponse {
472    fn from(value: ErrorResponse) -> Self {
473        HttpResponse::Error(value)
474    }
475}
476
477impl From<GreptimedbV1Response> for HttpResponse {
478    fn from(value: GreptimedbV1Response) -> Self {
479        HttpResponse::GreptimedbV1(value)
480    }
481}
482
483impl From<InfluxdbV1Response> for HttpResponse {
484    fn from(value: InfluxdbV1Response) -> Self {
485        HttpResponse::InfluxdbV1(value)
486    }
487}
488
489impl From<JsonResponse> for HttpResponse {
490    fn from(value: JsonResponse) -> Self {
491        HttpResponse::Json(value)
492    }
493}
494
495impl From<NullResponse> for HttpResponse {
496    fn from(value: NullResponse) -> Self {
497        HttpResponse::Null(value)
498    }
499}
500
501#[derive(Clone)]
502pub struct ApiState {
503    pub sql_handler: ServerSqlQueryHandlerRef,
504}
505
506#[derive(Clone)]
507pub struct GreptimeOptionsConfigState {
508    pub greptime_config_options: String,
509}
510
511#[derive(Clone)]
512pub struct DashboardState {
513    pub handler: DashboardHandlerRef,
514}
515
516pub struct HttpServerBuilder {
517    options: HttpOptions,
518    plugins: Plugins,
519    user_provider: Option<UserProviderRef>,
520    router: Router,
521    memory_limiter: ServerMemoryLimiter,
522}
523
524impl HttpServerBuilder {
525    pub fn new(options: HttpOptions) -> Self {
526        Self {
527            options,
528            plugins: Plugins::default(),
529            user_provider: None,
530            router: Router::new(),
531            memory_limiter: ServerMemoryLimiter::default(),
532        }
533    }
534
535    /// Set a global memory limiter for all server protocols.
536    pub fn with_memory_limiter(mut self, limiter: ServerMemoryLimiter) -> Self {
537        self.memory_limiter = limiter;
538        self
539    }
540
541    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
542        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
543
544        Self {
545            router: self
546                .router
547                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
548            ..self
549        }
550    }
551
552    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
553        let logs_router = HttpServer::route_logs(logs_handler);
554
555        Self {
556            router: self
557                .router
558                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
559            ..self
560        }
561    }
562
563    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
564        Self {
565            router: self.router.nest(
566                &format!("/{HTTP_API_VERSION}/opentsdb"),
567                HttpServer::route_opentsdb(handler),
568            ),
569            ..self
570        }
571    }
572
573    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
574        Self {
575            router: self.router.nest(
576                &format!("/{HTTP_API_VERSION}/influxdb"),
577                HttpServer::route_influxdb(handler),
578            ),
579            ..self
580        }
581    }
582
583    pub fn with_prom_handler(
584        self,
585        handler: PromStoreProtocolHandlerRef,
586        pipeline_handler: Option<PipelineHandlerRef>,
587        prom_store_with_metric_engine: bool,
588        prom_validation_mode: PromValidationMode,
589        pending_rows_batcher: Option<Arc<PendingRowsBatcher>>,
590    ) -> Self {
591        let state = PromStoreState {
592            prom_store_handler: handler,
593            pipeline_handler,
594            prom_store_with_metric_engine,
595            prom_validation_mode,
596            pending_rows_batcher,
597        };
598
599        Self {
600            router: self.router.nest(
601                &format!("/{HTTP_API_VERSION}/prometheus"),
602                HttpServer::route_prom(state),
603            ),
604            ..self
605        }
606    }
607
608    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
609        Self {
610            router: self.router.nest(
611                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
612                HttpServer::route_prometheus(handler),
613            ),
614            ..self
615        }
616    }
617
618    pub fn with_otlp_handler(
619        self,
620        handler: OpenTelemetryProtocolHandlerRef,
621        with_metric_engine: bool,
622    ) -> Self {
623        Self {
624            router: self.router.nest(
625                &format!("/{HTTP_API_VERSION}/otlp"),
626                HttpServer::route_otlp(handler, with_metric_engine),
627            ),
628            ..self
629        }
630    }
631
632    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
633        Self {
634            user_provider: Some(user_provider),
635            ..self
636        }
637    }
638
639    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
640        Self {
641            router: self.router.merge(HttpServer::route_metrics(handler)),
642            ..self
643        }
644    }
645
646    pub fn with_log_ingest_handler(
647        self,
648        handler: PipelineHandlerRef,
649        validator: Option<LogValidatorRef>,
650        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
651    ) -> Self {
652        let log_state = LogState {
653            log_handler: handler,
654            log_validator: validator,
655            ingest_interceptor,
656        };
657
658        let router = self.router.nest(
659            &format!("/{HTTP_API_VERSION}"),
660            HttpServer::route_pipelines(log_state.clone()),
661        );
662        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
663        let router = router.nest(
664            &format!("/{HTTP_API_VERSION}/events"),
665            #[allow(deprecated)]
666            HttpServer::route_log_deprecated(log_state.clone()),
667        );
668
669        let router = router.nest(
670            &format!("/{HTTP_API_VERSION}/loki"),
671            HttpServer::route_loki(log_state.clone()),
672        );
673
674        let router = router.nest(
675            &format!("/{HTTP_API_VERSION}/elasticsearch"),
676            HttpServer::route_elasticsearch(log_state.clone()),
677        );
678
679        let router = router.nest(
680            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
681            Router::new()
682                .route("/", routing::get(elasticsearch::handle_get_version))
683                .with_state(log_state),
684        );
685
686        Self { router, ..self }
687    }
688
689    pub fn with_plugins(self, plugins: Plugins) -> Self {
690        Self { plugins, ..self }
691    }
692
693    pub fn with_greptime_config_options(self, opts: String) -> Self {
694        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
695            greptime_config_options: opts,
696        });
697
698        Self {
699            router: self.router.merge(config_router),
700            ..self
701        }
702    }
703
704    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
705        Self {
706            router: self.router.nest(
707                &format!("/{HTTP_API_VERSION}/jaeger"),
708                HttpServer::route_jaeger(handler),
709            ),
710            ..self
711        }
712    }
713
714    pub fn with_dashboard_handler(self, handler: DashboardHandlerRef) -> Self {
715        Self {
716            router: self.router.nest(
717                &format!("/{HTTP_API_VERSION}/dashboards"),
718                HttpServer::route_dashboard(handler),
719            ),
720            ..self
721        }
722    }
723
724    pub fn with_extra_router(self, router: Router) -> Self {
725        Self {
726            router: self.router.merge(router),
727            ..self
728        }
729    }
730
731    pub fn add_layer<L>(self, layer: L) -> Self
732    where
733        L: Layer<Route> + Clone + Send + Sync + 'static,
734        L::Service: Service<Request> + Clone + Send + Sync + 'static,
735        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
736        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
737        <L::Service as Service<Request>>::Future: Send + 'static,
738    {
739        Self {
740            router: self.router.layer(layer),
741            ..self
742        }
743    }
744
745    pub fn build(self) -> HttpServer {
746        HttpServer {
747            options: self.options,
748            user_provider: self.user_provider,
749            shutdown_tx: Mutex::new(None),
750            plugins: self.plugins,
751            router: StdMutex::new(self.router),
752            bind_addr: None,
753            memory_limiter: self.memory_limiter,
754        }
755    }
756}
757
758impl HttpServer {
759    /// Gets the router and adds necessary root routes (health, status, dashboard).
760    pub fn make_app(&self) -> Router {
761        let mut router = {
762            let router = self.router.lock().unwrap();
763            router.clone()
764        };
765
766        router = router
767            .route(
768                "/health",
769                routing::get(handler::health).post(handler::health),
770            )
771            .route(
772                &format!("/{HTTP_API_VERSION}/health"),
773                routing::get(handler::health).post(handler::health),
774            )
775            .route(
776                "/ready",
777                routing::get(handler::health).post(handler::health),
778            );
779
780        router = router.route("/status", routing::get(handler::status));
781
782        #[cfg(feature = "dashboard")]
783        {
784            if !self.options.disable_dashboard {
785                info!("Enable dashboard service at '/dashboard'");
786                // redirect /dashboard to /dashboard/
787                router = router.route(
788                    "/dashboard",
789                    routing::get(|uri: axum::http::uri::Uri| async move {
790                        let path = uri.path();
791                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
792
793                        let new_uri = format!("{}/{}", path, query);
794                        axum::response::Redirect::permanent(&new_uri)
795                    }),
796                );
797
798                // "/dashboard" and "/dashboard/" are two different paths in Axum.
799                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
800                // So we explicitly route "/dashboard/" here.
801                router = router
802                    .route(
803                        "/dashboard/",
804                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
805                    )
806                    .route(
807                        "/dashboard/{*x}",
808                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
809                    );
810            }
811        }
812
813        // Add a layer to collect HTTP metrics for axum.
814        router = router.route_layer(middleware::from_fn(http_metrics_layer));
815
816        router
817    }
818
819    /// Attaches middlewares and debug routes to the router.
820    /// Callers should call this method after [HttpServer::make_app()].
821    pub fn build(&self, router: Router) -> Result<Router> {
822        let timeout_layer = if self.options.timeout != Duration::default() {
823            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
824        } else {
825            info!("HTTP server timeout is disabled");
826            None
827        };
828        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
829            Some(
830                ServiceBuilder::new()
831                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
832            )
833        } else {
834            info!("HTTP server body limit is disabled");
835            None
836        };
837        let cors_layer = if self.options.enable_cors {
838            Some(
839                CorsLayer::new()
840                    .allow_methods([
841                        Method::GET,
842                        Method::POST,
843                        Method::PUT,
844                        Method::DELETE,
845                        Method::HEAD,
846                    ])
847                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
848                        AllowOrigin::from(Any)
849                    } else {
850                        AllowOrigin::from(
851                            self.options
852                                .cors_allowed_origins
853                                .iter()
854                                .map(|s| {
855                                    HeaderValue::from_str(s.as_str())
856                                        .context(InvalidHeaderValueSnafu)
857                                })
858                                .collect::<Result<Vec<HeaderValue>>>()?,
859                        )
860                    })
861                    .allow_headers(Any),
862            )
863        } else {
864            info!("HTTP server cross-origin is disabled");
865            None
866        };
867
868        Ok(router
869            // middlewares
870            .layer(
871                ServiceBuilder::new()
872                    // disable on failure tracing. because printing out isn't very helpful,
873                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
874                    .layer(TraceLayer::new_for_http().on_failure(()))
875                    .option_layer(cors_layer)
876                    .option_layer(timeout_layer)
877                    .option_layer(body_limit_layer)
878                    // memory limit layer - must be before body is consumed
879                    .layer(middleware::from_fn_with_state(
880                        self.memory_limiter.clone(),
881                        memory_limit::memory_limit_middleware,
882                    ))
883                    // auth layer
884                    .layer(middleware::from_fn_with_state(
885                        AuthState::new(self.user_provider.clone()),
886                        authorize::check_http_auth,
887                    ))
888                    .layer(middleware::from_fn(hints::extract_hints))
889                    .layer(middleware::from_fn(client_ip::log_error_with_client_ip))
890                    .layer(middleware::from_fn(
891                        read_preference::extract_read_preference,
892                    )),
893            )
894            // Handlers for debug, we don't expect a timeout.
895            .nest(
896                "/debug",
897                Router::new()
898                    // handler for changing log level dynamically
899                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
900                    .route("/enable_trace", routing::post(dyn_trace::dyn_trace_handler))
901                    .nest(
902                        "/prof",
903                        Router::new()
904                            .route("/cpu", routing::post(pprof::pprof_handler))
905                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
906                            .route("/mem/symbol", routing::post(mem_prof::symbolicate_handler))
907                            .route(
908                                "/mem/activate",
909                                routing::post(mem_prof::activate_heap_prof_handler),
910                            )
911                            .route(
912                                "/mem/deactivate",
913                                routing::post(mem_prof::deactivate_heap_prof_handler),
914                            )
915                            .route(
916                                "/mem/status",
917                                routing::get(mem_prof::heap_prof_status_handler),
918                            ) // jemalloc gdump flag status and toggle
919                            .route(
920                                "/mem/gdump",
921                                routing::get(mem_prof::gdump_status_handler)
922                                    .post(mem_prof::gdump_toggle_handler),
923                            ),
924                    ),
925            ))
926    }
927
928    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
929        Router::new()
930            .route("/metrics", routing::get(handler::metrics))
931            .with_state(metrics_handler)
932    }
933
934    fn route_loki<S>(log_state: LogState) -> Router<S> {
935        Router::new()
936            .route("/api/v1/push", routing::post(loki::loki_ingest))
937            .layer(
938                ServiceBuilder::new()
939                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
940            )
941            .with_state(log_state)
942    }
943
944    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
945        Router::new()
946            // Return fake responsefor HEAD '/' request.
947            .route(
948                "/",
949                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
950            )
951            // Return fake response for Elasticsearch version request.
952            .route("/", routing::get(elasticsearch::handle_get_version))
953            // Return fake response for Elasticsearch license request.
954            .route("/_license", routing::get(elasticsearch::handle_get_license))
955            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
956            .route(
957                "/{index}/_bulk",
958                routing::post(elasticsearch::handle_bulk_api_with_index),
959            )
960            // Return fake response for Elasticsearch ilm request.
961            .route(
962                "/_ilm/policy/{*path}",
963                routing::any((
964                    HttpStatusCode::OK,
965                    elasticsearch::elasticsearch_headers(),
966                    axum::Json(serde_json::json!({})),
967                )),
968            )
969            // Return fake response for Elasticsearch index template request.
970            .route(
971                "/_index_template/{*path}",
972                routing::any((
973                    HttpStatusCode::OK,
974                    elasticsearch::elasticsearch_headers(),
975                    axum::Json(serde_json::json!({})),
976                )),
977            )
978            // Return fake response for Elasticsearch ingest pipeline request.
979            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
980            .route(
981                "/_ingest/{*path}",
982                routing::any((
983                    HttpStatusCode::OK,
984                    elasticsearch::elasticsearch_headers(),
985                    axum::Json(serde_json::json!({})),
986                )),
987            )
988            // Return fake response for Elasticsearch nodes discovery request.
989            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
990            .route(
991                "/_nodes/{*path}",
992                routing::any((
993                    HttpStatusCode::OK,
994                    elasticsearch::elasticsearch_headers(),
995                    axum::Json(serde_json::json!({})),
996                )),
997            )
998            // Return fake response for Logstash APIs requests.
999            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
1000            .route(
1001                "/logstash/{*path}",
1002                routing::any((
1003                    HttpStatusCode::OK,
1004                    elasticsearch::elasticsearch_headers(),
1005                    axum::Json(serde_json::json!({})),
1006                )),
1007            )
1008            .route(
1009                "/_logstash/{*path}",
1010                routing::any((
1011                    HttpStatusCode::OK,
1012                    elasticsearch::elasticsearch_headers(),
1013                    axum::Json(serde_json::json!({})),
1014                )),
1015            )
1016            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1017            .with_state(log_state)
1018    }
1019
1020    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1021    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1022        Router::new()
1023            .route("/logs", routing::post(event::log_ingester))
1024            .route(
1025                "/pipelines/{pipeline_name}",
1026                routing::get(event::query_pipeline),
1027            )
1028            .route(
1029                "/pipelines/{pipeline_name}",
1030                routing::post(event::add_pipeline),
1031            )
1032            .route(
1033                "/pipelines/{pipeline_name}",
1034                routing::delete(event::delete_pipeline),
1035            )
1036            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1037            .layer(
1038                ServiceBuilder::new()
1039                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1040            )
1041            .with_state(log_state)
1042    }
1043
1044    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1045        Router::new()
1046            .route("/ingest", routing::post(event::log_ingester))
1047            .route(
1048                "/pipelines/{pipeline_name}",
1049                routing::get(event::query_pipeline),
1050            )
1051            .route(
1052                "/pipelines/{pipeline_name}/ddl",
1053                routing::get(event::query_pipeline_ddl),
1054            )
1055            .route(
1056                "/pipelines/{pipeline_name}",
1057                routing::post(event::add_pipeline),
1058            )
1059            .route(
1060                "/pipelines/{pipeline_name}",
1061                routing::delete(event::delete_pipeline),
1062            )
1063            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1064            .layer(
1065                ServiceBuilder::new()
1066                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1067            )
1068            .with_state(log_state)
1069    }
1070
1071    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1072        Router::new()
1073            .route("/sql", routing::get(handler::sql).post(handler::sql))
1074            .route(
1075                "/sql/parse",
1076                routing::get(handler::sql_parse).post(handler::sql_parse),
1077            )
1078            .route(
1079                "/sql/format",
1080                routing::get(handler::sql_format).post(handler::sql_format),
1081            )
1082            .route(
1083                "/promql",
1084                routing::get(handler::promql).post(handler::promql),
1085            )
1086            .with_state(api_state)
1087    }
1088
1089    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1090        Router::new()
1091            .route("/logs", routing::get(logs::logs).post(logs::logs))
1092            .with_state(log_handler)
1093    }
1094
1095    /// Route Prometheus [HTTP API].
1096    ///
1097    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1098    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1099        Router::new()
1100            .route(
1101                "/format_query",
1102                routing::post(format_query).get(format_query),
1103            )
1104            .route("/status/buildinfo", routing::get(build_info_query))
1105            .route("/query", routing::post(instant_query).get(instant_query))
1106            .route("/query_range", routing::post(range_query).get(range_query))
1107            .route("/labels", routing::post(labels_query).get(labels_query))
1108            .route("/series", routing::post(series_query).get(series_query))
1109            .route("/parse_query", routing::post(parse_query).get(parse_query))
1110            .route(
1111                "/label/{label_name}/values",
1112                routing::get(label_values_query),
1113            )
1114            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1115            .with_state(prometheus_handler)
1116    }
1117
1118    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1119    /// called `prom_store`.
1120    ///
1121    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1122    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1123    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1124        Router::new()
1125            .route("/read", routing::post(prom_store::remote_read))
1126            .route("/write", routing::post(prom_store::remote_write))
1127            .with_state(state)
1128    }
1129
1130    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1131        Router::new()
1132            .route("/write", routing::post(influxdb_write_v1))
1133            .route("/api/v2/write", routing::post(influxdb_write_v2))
1134            .layer(
1135                ServiceBuilder::new()
1136                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1137            )
1138            .route("/ping", routing::get(influxdb_ping))
1139            .route("/health", routing::get(influxdb_health))
1140            .with_state(influxdb_handler)
1141    }
1142
1143    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1144        Router::new()
1145            .route("/api/put", routing::post(opentsdb::put))
1146            .with_state(opentsdb_handler)
1147    }
1148
1149    fn route_otlp<S>(
1150        otlp_handler: OpenTelemetryProtocolHandlerRef,
1151        with_metric_engine: bool,
1152    ) -> Router<S> {
1153        Router::new()
1154            .route("/v1/metrics", routing::post(otlp::metrics))
1155            .route("/v1/traces", routing::post(otlp::traces))
1156            .route("/v1/logs", routing::post(otlp::logs))
1157            .layer(
1158                ServiceBuilder::new()
1159                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1160            )
1161            .with_state(OtlpState {
1162                with_metric_engine,
1163                handler: otlp_handler,
1164            })
1165    }
1166
1167    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1168        Router::new()
1169            .route("/config", routing::get(handler::config))
1170            .with_state(state)
1171    }
1172
1173    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1174        Router::new()
1175            .route("/api/services", routing::get(jaeger::handle_get_services))
1176            .route(
1177                "/api/services/{service_name}/operations",
1178                routing::get(jaeger::handle_get_operations_by_service),
1179            )
1180            .route(
1181                "/api/operations",
1182                routing::get(jaeger::handle_get_operations),
1183            )
1184            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1185            .route(
1186                "/api/traces/{trace_id}",
1187                routing::get(jaeger::handle_get_trace),
1188            )
1189            .with_state(handler)
1190    }
1191
1192    #[cfg(feature = "dashboard")]
1193    fn route_dashboard<S>(handler: DashboardHandlerRef) -> Router<S> {
1194        use crate::http::dashboard::{add_dashboard, delete_dashboard, list_dashboards};
1195
1196        Router::new()
1197            .route("/", routing::get(list_dashboards))
1198            .route("/{dashboard_name}", routing::post(add_dashboard))
1199            .route("/{dashboard_name}", routing::delete(delete_dashboard))
1200            .layer(
1201                ServiceBuilder::new()
1202                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1203            )
1204            .with_state(DashboardState { handler })
1205    }
1206
1207    #[cfg(not(feature = "dashboard"))]
1208    fn route_dashboard<S>(handler: DashboardHandlerRef) -> Router<S> {
1209        Router::new().with_state(DashboardState { handler })
1210    }
1211}
1212
/// Name reported by [`HttpServer`] through `Server::name`.
pub const HTTP_SERVER: &str = "HTTP_SERVER";
1214
1215#[async_trait]
1216impl Server for HttpServer {
1217    async fn shutdown(&self) -> Result<()> {
1218        let mut shutdown_tx = self.shutdown_tx.lock().await;
1219        if let Some(tx) = shutdown_tx.take()
1220            && tx.send(()).is_err()
1221        {
1222            info!("Receiver dropped, the HTTP server has already exited");
1223        }
1224        info!("Shutdown HTTP server");
1225
1226        Ok(())
1227    }
1228
1229    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1230        let (tx, rx) = oneshot::channel();
1231        let serve = {
1232            let mut shutdown_tx = self.shutdown_tx.lock().await;
1233            ensure!(
1234                shutdown_tx.is_none(),
1235                AlreadyStartedSnafu { server: "HTTP" }
1236            );
1237
1238            let mut app = self.make_app();
1239            if let Some(configurator) = self.plugins.get::<HttpConfiguratorRef<()>>() {
1240                app = configurator
1241                    .configure_http(app, ())
1242                    .await
1243                    .context(OtherSnafu)?;
1244            }
1245            let app = self.build(app)?;
1246            let listener = tokio::net::TcpListener::bind(listening)
1247                .await
1248                .context(AddressBindSnafu { addr: listening })?
1249                .tap_io(|tcp_stream| {
1250                    if let Err(e) = tcp_stream.set_nodelay(true) {
1251                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1252                    }
1253                });
1254            let serve = axum::serve(
1255                listener,
1256                app.into_make_service_with_connect_info::<SocketAddr>(),
1257            );
1258
1259            // FIXME(yingwen): Support keepalive.
1260            // See:
1261            // - https://github.com/tokio-rs/axum/discussions/2939
1262            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1263            // let server = axum::Server::try_bind(&listening)
1264            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1265            //     .tcp_nodelay(true)
1266            //     // Enable TCP keepalive to close the dangling established connections.
1267            //     // It's configured to let the keepalive probes first send after the connection sits
1268            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1269            //     // So the connection will be closed after roughly 1 hour.
1270            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1271            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1272            //     .tcp_keepalive_retries(Some(6))
1273            //     .serve(app.into_make_service());
1274
1275            *shutdown_tx = Some(tx);
1276
1277            serve
1278        };
1279        let listening = serve.local_addr().context(InternalIoSnafu)?;
1280        info!("HTTP server is bound to {}", listening);
1281
1282        common_runtime::spawn_global(async move {
1283            if let Err(e) = serve
1284                .with_graceful_shutdown(rx.map(drop))
1285                .await
1286                .context(InternalIoSnafu)
1287            {
1288                error!(e; "Failed to shutdown http server");
1289            }
1290        });
1291
1292        self.bind_addr = Some(listening);
1293        Ok(())
1294    }
1295
1296    fn name(&self) -> &str {
1297        HTTP_SERVER
1298    }
1299
1300    fn bind_addr(&self) -> Option<SocketAddr> {
1301        self.bind_addr
1302    }
1303
1304    fn as_any(&self) -> &dyn std::any::Any {
1305        self
1306    }
1307}
1308
1309#[cfg(test)]
1310mod test {
1311    use std::future::pending;
1312    use std::io::Cursor;
1313    use std::sync::Arc;
1314
1315    use arrow_ipc::reader::FileReader;
1316    use arrow_schema::DataType;
1317    use axum::handler::Handler;
1318    use axum::http::StatusCode;
1319    use axum::routing::get;
1320    use common_query::Output;
1321    use common_recordbatch::RecordBatches;
1322    use datafusion_expr::LogicalPlan;
1323    use datatypes::prelude::*;
1324    use datatypes::schema::{ColumnSchema, Schema};
1325    use datatypes::vectors::{StringVector, UInt32Vector};
1326    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1327    use query::parser::PromQuery;
1328    use query::query_engine::DescribeResult;
1329    use session::context::QueryContextRef;
1330    use sql::statements::statement::Statement;
1331    use tokio::sync::mpsc;
1332    use tokio::time::Instant;
1333
1334    use super::*;
1335    use crate::http::test_helpers::TestClient;
1336    use crate::prom_remote_write::validation::validate_label_name;
1337    use crate::query_handler::sql::SqlQueryHandler;
1338
1339    struct DummyInstance {
1340        _tx: mpsc::Sender<(String, Vec<u8>)>,
1341    }
1342
1343    #[async_trait]
1344    impl SqlQueryHandler for DummyInstance {
1345        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1346            unimplemented!()
1347        }
1348
1349        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
1350            unimplemented!()
1351        }
1352
1353        async fn do_exec_plan(
1354            &self,
1355            _stmt: Option<Statement>,
1356            _plan: LogicalPlan,
1357            _query_ctx: QueryContextRef,
1358        ) -> Result<Output> {
1359            unimplemented!()
1360        }
1361
1362        async fn do_describe(
1363            &self,
1364            _stmt: sql::statements::statement::Statement,
1365            _query_ctx: QueryContextRef,
1366        ) -> Result<Option<DescribeResult>> {
1367            unimplemented!()
1368        }
1369
1370        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1371            Ok(true)
1372        }
1373    }
1374
1375    fn timeout() -> DynamicTimeoutLayer {
1376        DynamicTimeoutLayer::new(Duration::from_millis(10))
1377    }
1378
1379    async fn forever() {
1380        pending().await
1381    }
1382
1383    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1384        make_test_app_custom(tx, HttpOptions::default())
1385    }
1386
1387    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1388        let instance = Arc::new(DummyInstance { _tx: tx });
1389        let server = HttpServerBuilder::new(options)
1390            .with_sql_handler(instance.clone())
1391            .build();
1392        server.build(server.make_app()).unwrap().route(
1393            "/test/timeout",
1394            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1395        )
1396    }
1397
1398    #[tokio::test]
1399    pub async fn test_cors() {
1400        // cors is on by default
1401        let (tx, _rx) = mpsc::channel(100);
1402        let app = make_test_app(tx);
1403        let client = TestClient::new(app).await;
1404
1405        let res = client.get("/health").send().await;
1406
1407        assert_eq!(res.status(), StatusCode::OK);
1408        assert_eq!(
1409            res.headers()
1410                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1411                .expect("expect cors header origin"),
1412            "*"
1413        );
1414
1415        let res = client.get("/v1/health").send().await;
1416
1417        assert_eq!(res.status(), StatusCode::OK);
1418        assert_eq!(
1419            res.headers()
1420                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1421                .expect("expect cors header origin"),
1422            "*"
1423        );
1424
1425        let res = client
1426            .options("/health")
1427            .header("Access-Control-Request-Headers", "x-greptime-auth")
1428            .header("Access-Control-Request-Method", "DELETE")
1429            .header("Origin", "https://example.com")
1430            .send()
1431            .await;
1432        assert_eq!(res.status(), StatusCode::OK);
1433        assert_eq!(
1434            res.headers()
1435                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1436                .expect("expect cors header origin"),
1437            "*"
1438        );
1439        assert_eq!(
1440            res.headers()
1441                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1442                .expect("expect cors header headers"),
1443            "*"
1444        );
1445        assert_eq!(
1446            res.headers()
1447                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1448                .expect("expect cors header methods"),
1449            "GET,POST,PUT,DELETE,HEAD"
1450        );
1451    }
1452
1453    #[tokio::test]
1454    pub async fn test_cors_custom_origins() {
1455        // cors is on by default
1456        let (tx, _rx) = mpsc::channel(100);
1457        let origin = "https://example.com";
1458
1459        let options = HttpOptions {
1460            cors_allowed_origins: vec![origin.to_string()],
1461            ..Default::default()
1462        };
1463
1464        let app = make_test_app_custom(tx, options);
1465        let client = TestClient::new(app).await;
1466
1467        let res = client.get("/health").header("Origin", origin).send().await;
1468
1469        assert_eq!(res.status(), StatusCode::OK);
1470        assert_eq!(
1471            res.headers()
1472                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1473                .expect("expect cors header origin"),
1474            origin
1475        );
1476
1477        let res = client
1478            .get("/health")
1479            .header("Origin", "https://notallowed.com")
1480            .send()
1481            .await;
1482
1483        assert_eq!(res.status(), StatusCode::OK);
1484        assert!(
1485            !res.headers()
1486                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1487        );
1488    }
1489
1490    #[tokio::test]
1491    pub async fn test_cors_disabled() {
1492        // cors is on by default
1493        let (tx, _rx) = mpsc::channel(100);
1494
1495        let options = HttpOptions {
1496            enable_cors: false,
1497            ..Default::default()
1498        };
1499
1500        let app = make_test_app_custom(tx, options);
1501        let client = TestClient::new(app).await;
1502
1503        let res = client.get("/health").send().await;
1504
1505        assert_eq!(res.status(), StatusCode::OK);
1506        assert!(
1507            !res.headers()
1508                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1509        );
1510    }
1511
1512    #[test]
1513    fn test_http_options_default() {
1514        let default = HttpOptions::default();
1515        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1516        assert_eq!(Duration::from_secs(0), default.timeout)
1517    }
1518
1519    #[tokio::test]
1520    async fn test_http_server_request_timeout() {
1521        common_telemetry::init_default_ut_logging();
1522
1523        let (tx, _rx) = mpsc::channel(100);
1524        let app = make_test_app(tx);
1525        let client = TestClient::new(app).await;
1526        let res = client.get("/test/timeout").send().await;
1527        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1528
1529        let now = Instant::now();
1530        let res = client
1531            .get("/test/timeout")
1532            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1533            .send()
1534            .await;
1535        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1536        let elapsed = now.elapsed();
1537        assert!(elapsed > Duration::from_millis(15));
1538
1539        tokio::time::timeout(
1540            Duration::from_millis(15),
1541            client
1542                .get("/test/timeout")
1543                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1544                .send(),
1545        )
1546        .await
1547        .unwrap_err();
1548
1549        tokio::time::timeout(
1550            Duration::from_millis(15),
1551            client
1552                .get("/test/timeout")
1553                .header(
1554                    GREPTIME_DB_HEADER_TIMEOUT,
1555                    humantime::format_duration(Duration::default()).to_string(),
1556                )
1557                .send(),
1558        )
1559        .await
1560        .unwrap_err();
1561    }
1562
1563    #[tokio::test]
1564    async fn test_schema_for_empty_response() {
1565        let column_schemas = vec![
1566            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1567            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1568        ];
1569        let schema = Arc::new(Schema::new(column_schemas));
1570
1571        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1572        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1573
1574        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1575        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1576            let json_output = &json_resp.output[0];
1577            if let GreptimeQueryOutput::Records(r) = json_output {
1578                assert_eq!(r.num_rows(), 0);
1579                assert_eq!(r.num_cols(), 2);
1580                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1581                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1582            } else {
1583                panic!("invalid output type");
1584            }
1585        } else {
1586            panic!("invalid format")
1587        }
1588    }
1589
1590    #[tokio::test]
1591    async fn test_recordbatches_conversion() {
1592        let column_schemas = vec![
1593            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1594            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1595        ];
1596        let schema = Arc::new(Schema::new(column_schemas));
1597        let columns: Vec<VectorRef> = vec![
1598            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1599            Arc::new(StringVector::from(vec![
1600                None,
1601                Some("hello"),
1602                Some("greptime"),
1603                None,
1604            ])),
1605        ];
1606        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1607
1608        for format in [
1609            ResponseFormat::GreptimedbV1,
1610            ResponseFormat::InfluxdbV1,
1611            ResponseFormat::Csv(true, true),
1612            ResponseFormat::Table,
1613            ResponseFormat::Arrow,
1614            ResponseFormat::Json,
1615            ResponseFormat::Null,
1616        ] {
1617            let recordbatches =
1618                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1619            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1620            let json_resp = match format {
1621                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1622                ResponseFormat::Csv(with_names, with_types) => {
1623                    CsvResponse::from_output(outputs, with_names, with_types).await
1624                }
1625                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1626                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1627                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1628                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1629                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1630            };
1631
1632            match json_resp {
1633                HttpResponse::GreptimedbV1(resp) => {
1634                    let json_output = &resp.output[0];
1635                    if let GreptimeQueryOutput::Records(r) = json_output {
1636                        assert_eq!(r.num_rows(), 4);
1637                        assert_eq!(r.num_cols(), 2);
1638                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1639                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1640                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1641                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1642                    } else {
1643                        panic!("invalid output type");
1644                    }
1645                }
1646                HttpResponse::InfluxdbV1(resp) => {
1647                    let json_output = &resp.results()[0];
1648                    assert_eq!(json_output.num_rows(), 4);
1649                    assert_eq!(json_output.num_cols(), 2);
1650                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1651                    assert_eq!(
1652                        json_output.series[0].values[0][0],
1653                        serde_json::Value::from(1)
1654                    );
1655                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1656                }
1657                HttpResponse::Csv(resp) => {
1658                    let output = &resp.output()[0];
1659                    if let GreptimeQueryOutput::Records(r) = output {
1660                        assert_eq!(r.num_rows(), 4);
1661                        assert_eq!(r.num_cols(), 2);
1662                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1663                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1664                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1665                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1666                    } else {
1667                        panic!("invalid output type");
1668                    }
1669                }
1670
1671                HttpResponse::Table(resp) => {
1672                    let output = &resp.output()[0];
1673                    if let GreptimeQueryOutput::Records(r) = output {
1674                        assert_eq!(r.num_rows(), 4);
1675                        assert_eq!(r.num_cols(), 2);
1676                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1677                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1678                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1679                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1680                    } else {
1681                        panic!("invalid output type");
1682                    }
1683                }
1684
1685                HttpResponse::Arrow(resp) => {
1686                    let output = resp.data;
1687                    let mut reader =
1688                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
1689                    let schema = reader.schema();
1690                    assert_eq!(schema.fields[0].name(), "numbers");
1691                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1692                    assert_eq!(schema.fields[1].name(), "strings");
1693                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1694
1695                    let rb = reader.next().unwrap().expect("read record batch failed");
1696                    assert_eq!(rb.num_columns(), 2);
1697                    assert_eq!(rb.num_rows(), 4);
1698                }
1699
1700                HttpResponse::Json(resp) => {
1701                    let output = &resp.output()[0];
1702                    if let GreptimeQueryOutput::Records(r) = output {
1703                        assert_eq!(r.num_rows(), 4);
1704                        assert_eq!(r.num_cols(), 2);
1705                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1706                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1707                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1708                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1709                    } else {
1710                        panic!("invalid output type");
1711                    }
1712                }
1713
1714                HttpResponse::Null(resp) => {
1715                    assert_eq!(resp.rows(), 4);
1716                }
1717
1718                HttpResponse::Error(err) => unreachable!("{err:?}"),
1719            }
1720        }
1721    }
1722
1723    #[test]
1724    fn test_response_format_misc() {
1725        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
1726        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
1727        assert_eq!(
1728            ResponseFormat::parse("csv"),
1729            Some(ResponseFormat::Csv(false, false))
1730        );
1731        assert_eq!(
1732            ResponseFormat::parse("csvwithnames"),
1733            Some(ResponseFormat::Csv(true, false))
1734        );
1735        assert_eq!(
1736            ResponseFormat::parse("csvwithnamesandtypes"),
1737            Some(ResponseFormat::Csv(true, true))
1738        );
1739        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
1740        assert_eq!(
1741            ResponseFormat::parse("greptimedb_v1"),
1742            Some(ResponseFormat::GreptimedbV1)
1743        );
1744        assert_eq!(
1745            ResponseFormat::parse("influxdb_v1"),
1746            Some(ResponseFormat::InfluxdbV1)
1747        );
1748        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
1749        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));
1750
1751        // invalid formats
1752        assert_eq!(ResponseFormat::parse("invalid"), None);
1753        assert_eq!(ResponseFormat::parse(""), None);
1754        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive
1755
1756        // as str
1757        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
1758        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
1759        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
1760        assert_eq!(ResponseFormat::Table.as_str(), "table");
1761        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
1762        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
1763        assert_eq!(ResponseFormat::Json.as_str(), "json");
1764        assert_eq!(ResponseFormat::Null.as_str(), "null");
1765        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
1766    }
1767
1768    #[test]
1769    fn test_decode_label_name_strict() {
1770        let strict = PromValidationMode::Strict;
1771
1772        // Valid Prometheus label names
1773        assert!(strict.decode_label_name(b"__name__").is_ok());
1774        assert!(strict.decode_label_name(b"job").is_ok());
1775        assert!(strict.decode_label_name(b"instance").is_ok());
1776        assert!(strict.decode_label_name(b"_private").is_ok());
1777        assert!(strict.decode_label_name(b"label_with_underscores").is_ok());
1778        assert!(strict.decode_label_name(b"abc123").is_ok());
1779        assert!(strict.decode_label_name(b"A").is_ok());
1780        assert!(strict.decode_label_name(b"_").is_ok());
1781
1782        // Invalid: starts with digit
1783        assert!(strict.decode_label_name(b"0abc").is_err());
1784        assert!(strict.decode_label_name(b"123").is_err());
1785
1786        // Invalid: contains special characters
1787        assert!(strict.decode_label_name(b"label-name").is_err());
1788        assert!(strict.decode_label_name(b"label.name").is_err());
1789        assert!(strict.decode_label_name(b"label name").is_err());
1790        assert!(strict.decode_label_name(b"label/name").is_err());
1791
1792        // Invalid: empty
1793        assert!(strict.decode_label_name(b"").is_err());
1794
1795        // Invalid: non-ASCII UTF-8
1796        assert!(strict.decode_label_name("ラベル".as_bytes()).is_err());
1797
1798        // Invalid UTF-8 bytes should fail
1799        assert!(strict.decode_label_name(&[0xff, 0xfe]).is_err());
1800    }
1801
1802    #[test]
1803    fn test_decode_label_name_lossy() {
1804        let lossy = PromValidationMode::Lossy;
1805
1806        // Label name validation is always enforced.
1807        assert!(lossy.decode_label_name(b"__name__").is_ok());
1808        assert!(lossy.decode_label_name(b"label-name").is_err());
1809        assert!(lossy.decode_label_name(b"0abc").is_err());
1810
1811        // Invalid UTF-8 bytes fail the label-name byte check.
1812        assert!(lossy.decode_label_name(&[0xff, 0xfe]).is_err());
1813    }
1814
1815    #[test]
1816    fn test_decode_label_name_unchecked() {
1817        let unchecked = PromValidationMode::Unchecked;
1818
1819        // Label name validation is always enforced.
1820        assert!(unchecked.decode_label_name(b"__name__").is_ok());
1821        assert!(unchecked.decode_label_name(b"label-name").is_err());
1822        assert!(unchecked.decode_label_name(b"0abc").is_err());
1823    }
1824
1825    #[test]
1826    fn test_is_valid_prom_label_name_bytes() {
1827        assert!(validate_label_name(b"__name__"));
1828        assert!(validate_label_name(b"job"));
1829        assert!(validate_label_name(b"_"));
1830        assert!(validate_label_name(b"A"));
1831        assert!(validate_label_name(b"abc123"));
1832        assert!(validate_label_name(b"_leading_underscore"));
1833
1834        assert!(!validate_label_name(b""));
1835        assert!(!validate_label_name(b"0starts_with_digit"));
1836        assert!(!validate_label_name(b"has-dash"));
1837        assert!(!validate_label_name(b"has.dot"));
1838        assert!(!validate_label_name(b"has space"));
1839        assert!(!validate_label_name(&[0xff, 0xfe]));
1840    }
1841}