Skip to main content

servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::fmt::Display;
18use std::net::SocketAddr;
19use std::sync::{Arc, Mutex as StdMutex};
20use std::time::Duration;
21
22use async_trait::async_trait;
23use auth::UserProviderRef;
24use axum::extract::{DefaultBodyLimit, Request};
25use axum::http::StatusCode as HttpStatusCode;
26use axum::response::{IntoResponse, Response};
27use axum::routing::Route;
28use axum::serve::ListenerExt;
29use axum::{Router, middleware, routing};
30use common_base::readable_size::ReadableSize;
31use common_recordbatch::RecordBatch;
32use common_telemetry::{error, info};
33use common_time::Timestamp;
34use common_time::timestamp::TimeUnit;
35use datatypes::data_type::DataType;
36use datatypes::schema::SchemaRef;
37use event::{LogState, LogValidatorRef};
38use futures::FutureExt;
39use http::{HeaderValue, Method};
40use serde::{Deserialize, Serialize};
41use serde_json::Value;
42use snafu::{ResultExt, ensure};
43use tokio::sync::Mutex;
44use tokio::sync::oneshot::{self, Sender};
45use tonic::codegen::Service;
46use tower::{Layer, ServiceBuilder};
47use tower_http::compression::CompressionLayer;
48use tower_http::cors::{AllowOrigin, Any, CorsLayer};
49use tower_http::decompression::RequestDecompressionLayer;
50use tower_http::trace::TraceLayer;
51
52use self::authorize::AuthState;
53use self::result::table_result::TableResponse;
54use crate::elasticsearch;
55use crate::error::{
56    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu, Result,
57};
58use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
59use crate::http::otlp::OtlpState;
60use crate::http::prom_store::PromStoreState;
61use crate::http::prometheus::{
62    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
63    range_query, series_query,
64};
65use crate::http::result::arrow_result::ArrowResponse;
66use crate::http::result::csv_result::CsvResponse;
67use crate::http::result::error_result::ErrorResponse;
68use crate::http::result::greptime_result_v1::GreptimedbV1Response;
69use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
70use crate::http::result::json_result::JsonResponse;
71use crate::http::result::null_result::NullResponse;
72use crate::interceptor::LogIngestInterceptorRef;
73use crate::metrics::http_metrics_layer;
74use crate::metrics_handler::MetricsHandler;
75use crate::pending_rows_batcher::PendingRowsBatcher;
76use crate::prometheus_handler::PrometheusHandlerRef;
77use crate::query_handler::sql::ServerSqlQueryHandlerRef;
78use crate::query_handler::{
79    DashboardHandlerRef, InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
80    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
81    PromStoreProtocolHandlerRef,
82};
83use crate::request_memory_limiter::ServerMemoryLimiter;
84use crate::server::Server;
85
86pub mod authorize;
87#[cfg(feature = "dashboard")]
88mod dashboard;
89pub mod dyn_log;
90pub mod dyn_trace;
91pub mod event;
92pub mod extractor;
93pub mod handler;
94pub mod header;
95pub mod influxdb;
96pub mod jaeger;
97pub mod logs;
98pub mod loki;
99pub mod mem_prof;
100mod memory_limit;
101pub mod opentsdb;
102pub mod otlp;
103pub mod pprof;
104pub mod prom_store;
105pub mod prometheus;
106pub mod result;
107pub mod splunk;
108mod timeout;
109pub mod utils;
110
111use result::HttpOutputWriter;
112pub(crate) use timeout::DynamicTimeoutLayer;
113
114mod client_ip;
115use crate::prom_remote_write::validation::PromValidationMode;
116mod hints;
117mod read_preference;
118#[cfg(any(test, feature = "testing"))]
119pub mod test_helpers;
120
/// Version segment used to build versioned route prefixes (e.g. `/v1/...`).
pub const HTTP_API_VERSION: &str = "v1";
/// Versioned API path prefix, with a trailing slash.
pub const HTTP_API_PREFIX: &str = "/v1/";
/// Versioned API path prefix, without a trailing slash.
pub const HTTP_API_PREFIX_WITHOUT_TRAILING_SLASH: &str = "/v1";
/// Default http body limit (64M).
const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);

/// Authorization header
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";

// TODO(fys): This is a temporary workaround, it will be improved later
/// Health/ping endpoint paths (InfluxDB ping and health, `/v1/health`, Splunk collector health).
// NOTE(review): presumably consulted by the auth middleware to let these paths through
// without credentials — confirm in `authorize`.
pub static PUBLIC_API_PREFIX: [&str; 4] = [
    "/v1/influxdb/ping",
    "/v1/influxdb/health",
    "/v1/health",
    "/v1/splunk/services/collector/health",
];
137
/// The GreptimeDB HTTP server.
///
/// Produced by [`HttpServerBuilder::build`]. The routes collected by the builder live in
/// `router`; [`HttpServer::make_app`] clones them out and [`HttpServer::build`] wraps them
/// with middleware.
#[derive(Default)]
pub struct HttpServer {
    /// Routes registered through the builder; `make_app` clones the router out of this mutex.
    router: StdMutex<Router>,
    // NOTE(review): presumably holds the shutdown signal sender once the server starts — confirm.
    shutdown_tx: Mutex<Option<Sender<()>>>,
    /// Optional user provider handed to the auth middleware in `build`.
    user_provider: Option<UserProviderRef>,
    /// Limiter handed to the memory-limit middleware in `build`.
    memory_limiter: ServerMemoryLimiter,

    // server configs
    options: HttpOptions,
    // `None` when built by the builder; presumably filled in after binding — confirm.
    bind_addr: Option<SocketAddr>,
}
149
/// Configuration of the HTTP server.
///
/// Deserialized with `#[serde(default)]`, so missing fields take their values from
/// [`HttpOptions::default`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct HttpOptions {
    /// Server address (defaults to `127.0.0.1:4000`).
    pub addr: String,

    /// Request timeout (humantime format); a zero duration disables the timeout layer.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,

    /// When set, the `/dashboard` routes are not registered (only relevant with the
    /// `dashboard` feature). Never (de)serialized.
    #[serde(skip)]
    pub disable_dashboard: bool,

    /// Request body limit; a value of `0` means no limit is configured (see `HttpServer::build`).
    pub body_limit: ReadableSize,

    /// Validation mode while decoding Prometheus remote write requests.
    pub prom_validation_mode: PromValidationMode,

    /// Origins allowed by the CORS layer; an empty list allows any origin.
    pub cors_allowed_origins: Vec<String>,

    /// Whether the CORS layer is installed at all.
    pub enable_cors: bool,
}
170
171impl Default for HttpOptions {
172    fn default() -> Self {
173        Self {
174            addr: "127.0.0.1:4000".to_string(),
175            timeout: Duration::from_secs(0),
176            disable_dashboard: false,
177            body_limit: DEFAULT_BODY_LIMIT,
178            cors_allowed_origins: Vec::new(),
179            enable_cors: true,
180            prom_validation_mode: PromValidationMode::Strict,
181        }
182    }
183}
184
185#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
186pub struct ColumnSchema {
187    name: String,
188    data_type: String,
189}
190
191impl ColumnSchema {
192    pub fn new(name: String, data_type: String) -> ColumnSchema {
193        ColumnSchema { name, data_type }
194    }
195}
196
197#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
198pub struct OutputSchema {
199    column_schemas: Vec<ColumnSchema>,
200}
201
202impl OutputSchema {
203    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
204        OutputSchema {
205            column_schemas: columns,
206        }
207    }
208}
209
210impl From<SchemaRef> for OutputSchema {
211    fn from(schema: SchemaRef) -> OutputSchema {
212        OutputSchema {
213            column_schemas: schema
214                .column_schemas()
215                .iter()
216                .map(|cs| ColumnSchema {
217                    name: cs.name.clone(),
218                    data_type: cs.data_type.name(),
219                })
220                .collect(),
221        }
222    }
223}
224
/// Tabular query result as returned over HTTP: a schema plus rows of JSON values.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct HttpRecordsOutput {
    /// Column names and data type names of the result.
    schema: OutputSchema,
    /// Result rows; each inner vector is one row.
    rows: Vec<Vec<Value>>,
    // total_rows is equal to rows.len() in most cases,
    // the Dashboard query result may be truncated, so we need to return the total_rows.
    #[serde(default)]
    total_rows: usize,

    // plan level execution metrics
    // (omitted from the serialized output when empty)
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    #[serde(default)]
    metrics: HashMap<String, Value>,
}
239
240impl HttpRecordsOutput {
241    pub fn num_rows(&self) -> usize {
242        self.rows.len()
243    }
244
245    pub fn num_cols(&self) -> usize {
246        self.schema.column_schemas.len()
247    }
248
249    pub fn schema(&self) -> &OutputSchema {
250        &self.schema
251    }
252
253    pub fn rows(&self) -> &Vec<Vec<Value>> {
254        &self.rows
255    }
256}
257
258impl HttpRecordsOutput {
259    pub fn try_new(
260        schema: SchemaRef,
261        recordbatches: Vec<RecordBatch>,
262    ) -> std::result::Result<HttpRecordsOutput, Error> {
263        if recordbatches.is_empty() {
264            Ok(HttpRecordsOutput {
265                schema: OutputSchema::from(schema),
266                rows: vec![],
267                total_rows: 0,
268                metrics: Default::default(),
269            })
270        } else {
271            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
272            let mut rows = Vec::with_capacity(num_rows);
273
274            for recordbatch in recordbatches {
275                let mut writer = HttpOutputWriter::new(schema.num_columns(), None);
276                writer.write(recordbatch, &mut rows)?;
277            }
278
279            Ok(HttpRecordsOutput {
280                schema: OutputSchema::from(schema),
281                total_rows: rows.len(),
282                rows,
283                metrics: Default::default(),
284            })
285        }
286    }
287}
288
/// One statement's output: either an affected-row count or a set of records.
///
/// Serialized with lowercase variant names (`affectedrows`, `records`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GreptimeQueryOutput {
    AffectedRows(usize),
    Records(HttpRecordsOutput),
}
295
/// It allows the results of SQL queries to be presented in different formats.
///
/// [`ResponseFormat::GreptimedbV1`] is the default.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    /// `arrow`
    Arrow,
    // (with_names, with_types)
    /// `csv`, `csvwithnames` or `csvwithnamesandtypes`.
    Csv(bool, bool),
    /// `table`
    Table,
    /// `greptimedb_v1`
    #[default]
    GreptimedbV1,
    /// `influxdb_v1`
    InfluxdbV1,
    /// `json`
    Json,
    /// `null`
    Null,
}

impl ResponseFormat {
    /// Parses a (case-sensitive) format name; unknown names yield `None`.
    pub fn parse(s: &str) -> Option<Self> {
        let format = match s {
            "arrow" => Self::Arrow,
            "csv" => Self::Csv(false, false),
            "csvwithnames" => Self::Csv(true, false),
            "csvwithnamesandtypes" => Self::Csv(true, true),
            "table" => Self::Table,
            "greptimedb_v1" => Self::GreptimedbV1,
            "influxdb_v1" => Self::InfluxdbV1,
            "json" => Self::Json,
            "null" => Self::Null,
            _ => return None,
        };
        Some(format)
    }

    /// Canonical name of the format. Every CSV variant reports `"csv"`, so the header
    /// flags are not recoverable from this string.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
338
339#[derive(Debug, Clone, Copy, PartialEq, Eq)]
340pub enum Epoch {
341    Nanosecond,
342    Microsecond,
343    Millisecond,
344    Second,
345}
346
347impl Epoch {
348    pub fn parse(s: &str) -> Option<Epoch> {
349        // Both u and µ indicate microseconds.
350        // epoch = [ns,u,µ,ms,s],
351        // For details, see the Influxdb documents.
352        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
353        match s {
354            "ns" => Some(Epoch::Nanosecond),
355            "u" | "µ" => Some(Epoch::Microsecond),
356            "ms" => Some(Epoch::Millisecond),
357            "s" => Some(Epoch::Second),
358            _ => None, // just returns None for other cases
359        }
360    }
361
362    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
363        match self {
364            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
365            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
366            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
367            Epoch::Second => ts.convert_to(TimeUnit::Second),
368        }
369    }
370}
371
372impl Display for Epoch {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        match self {
375            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
376            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
377            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
378            Epoch::Second => write!(f, "Epoch::Second"),
379        }
380    }
381}
382
/// A query response in one of the supported output formats, or an error response.
///
/// Each variant wraps the format-specific response type, which does the actual rendering.
#[derive(Serialize, Deserialize, Debug)]
pub enum HttpResponse {
    Arrow(ArrowResponse),
    Csv(CsvResponse),
    Table(TableResponse),
    Error(ErrorResponse),
    GreptimedbV1(GreptimedbV1Response),
    InfluxdbV1(InfluxdbV1Response),
    Json(JsonResponse),
    Null(NullResponse),
}
394
395impl HttpResponse {
396    pub fn with_execution_time(self, execution_time: u64) -> Self {
397        match self {
398            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
399            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
400            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
401            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
402            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
403            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
404            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
405            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
406        }
407    }
408
409    pub fn with_limit(self, limit: usize) -> Self {
410        match self {
411            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
412            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
413            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
414            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
415            _ => self,
416        }
417    }
418}
419
420pub fn process_with_limit(
421    mut outputs: Vec<GreptimeQueryOutput>,
422    limit: usize,
423) -> Vec<GreptimeQueryOutput> {
424    outputs
425        .drain(..)
426        .map(|data| match data {
427            GreptimeQueryOutput::Records(mut records) => {
428                if records.rows.len() > limit {
429                    records.rows.truncate(limit);
430                    records.total_rows = limit;
431                }
432                GreptimeQueryOutput::Records(records)
433            }
434            _ => data,
435        })
436        .collect()
437}
438
439impl IntoResponse for HttpResponse {
440    fn into_response(self) -> Response {
441        match self {
442            HttpResponse::Arrow(resp) => resp.into_response(),
443            HttpResponse::Csv(resp) => resp.into_response(),
444            HttpResponse::Table(resp) => resp.into_response(),
445            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
446            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
447            HttpResponse::Json(resp) => resp.into_response(),
448            HttpResponse::Null(resp) => resp.into_response(),
449            HttpResponse::Error(resp) => resp.into_response(),
450        }
451    }
452}
453
454impl From<ArrowResponse> for HttpResponse {
455    fn from(value: ArrowResponse) -> Self {
456        HttpResponse::Arrow(value)
457    }
458}
459
460impl From<CsvResponse> for HttpResponse {
461    fn from(value: CsvResponse) -> Self {
462        HttpResponse::Csv(value)
463    }
464}
465
466impl From<TableResponse> for HttpResponse {
467    fn from(value: TableResponse) -> Self {
468        HttpResponse::Table(value)
469    }
470}
471
472impl From<ErrorResponse> for HttpResponse {
473    fn from(value: ErrorResponse) -> Self {
474        HttpResponse::Error(value)
475    }
476}
477
478impl From<GreptimedbV1Response> for HttpResponse {
479    fn from(value: GreptimedbV1Response) -> Self {
480        HttpResponse::GreptimedbV1(value)
481    }
482}
483
484impl From<InfluxdbV1Response> for HttpResponse {
485    fn from(value: InfluxdbV1Response) -> Self {
486        HttpResponse::InfluxdbV1(value)
487    }
488}
489
490impl From<JsonResponse> for HttpResponse {
491    fn from(value: JsonResponse) -> Self {
492        HttpResponse::Json(value)
493    }
494}
495
496impl From<NullResponse> for HttpResponse {
497    fn from(value: NullResponse) -> Self {
498        HttpResponse::Null(value)
499    }
500}
501
/// State shared by the SQL routes (passed to `HttpServer::route_sql`).
#[derive(Clone)]
pub struct ApiState {
    pub sql_handler: ServerSqlQueryHandlerRef,
}

/// State for the config route (passed to `HttpServer::route_config`).
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    // NOTE(review): presumably the serialized server options returned by that route — confirm.
    pub greptime_config_options: String,
}

/// State wrapper holding a dashboard handler.
#[derive(Clone)]
pub struct DashboardState {
    pub handler: DashboardHandlerRef,
}
516
/// Builder that accumulates routes, layers and settings before producing an [`HttpServer`].
pub struct HttpServerBuilder {
    /// Server options, passed through to the built server.
    options: HttpOptions,
    /// Optional user provider for authentication; `None` unless `with_user_provider` is called.
    user_provider: Option<UserProviderRef>,
    /// Routes registered so far.
    router: Router,
    /// Memory limiter handed to the built server.
    memory_limiter: ServerMemoryLimiter,
}
523
524impl HttpServerBuilder {
525    pub fn new(options: HttpOptions) -> Self {
526        Self {
527            options,
528            user_provider: None,
529            router: Router::new(),
530            memory_limiter: ServerMemoryLimiter::default(),
531        }
532    }
533
534    /// Set a global memory limiter for all server protocols.
535    pub fn with_memory_limiter(mut self, limiter: ServerMemoryLimiter) -> Self {
536        self.memory_limiter = limiter;
537        self
538    }
539
540    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
541        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
542
543        Self {
544            router: self
545                .router
546                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
547            ..self
548        }
549    }
550
551    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
552        let logs_router = HttpServer::route_logs(logs_handler);
553
554        Self {
555            router: self
556                .router
557                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
558            ..self
559        }
560    }
561
562    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
563        Self {
564            router: self.router.nest(
565                &format!("/{HTTP_API_VERSION}/opentsdb"),
566                HttpServer::route_opentsdb(handler),
567            ),
568            ..self
569        }
570    }
571
572    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
573        Self {
574            router: self.router.nest(
575                &format!("/{HTTP_API_VERSION}/influxdb"),
576                HttpServer::route_influxdb(handler),
577            ),
578            ..self
579        }
580    }
581
582    pub fn with_prom_handler(
583        self,
584        handler: PromStoreProtocolHandlerRef,
585        pipeline_handler: Option<PipelineHandlerRef>,
586        prom_store_with_metric_engine: bool,
587        prom_validation_mode: PromValidationMode,
588        pending_rows_batcher: Option<Arc<PendingRowsBatcher>>,
589    ) -> Self {
590        let state = PromStoreState {
591            prom_store_handler: handler,
592            pipeline_handler,
593            prom_store_with_metric_engine,
594            prom_validation_mode,
595            pending_rows_batcher,
596        };
597
598        Self {
599            router: self.router.nest(
600                &format!("/{HTTP_API_VERSION}/prometheus"),
601                HttpServer::route_prom(state),
602            ),
603            ..self
604        }
605    }
606
607    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
608        Self {
609            router: self.router.nest(
610                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
611                HttpServer::route_prometheus(handler),
612            ),
613            ..self
614        }
615    }
616
617    pub fn with_otlp_handler(
618        self,
619        handler: OpenTelemetryProtocolHandlerRef,
620        with_metric_engine: bool,
621    ) -> Self {
622        Self {
623            router: self.router.nest(
624                &format!("/{HTTP_API_VERSION}/otlp"),
625                HttpServer::route_otlp(handler, with_metric_engine),
626            ),
627            ..self
628        }
629    }
630
631    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
632        Self {
633            user_provider: Some(user_provider),
634            ..self
635        }
636    }
637
638    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
639        Self {
640            router: self.router.merge(HttpServer::route_metrics(handler)),
641            ..self
642        }
643    }
644
645    pub fn with_log_ingest_handler(
646        self,
647        handler: PipelineHandlerRef,
648        validator: Option<LogValidatorRef>,
649        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
650    ) -> Self {
651        let log_state = LogState {
652            log_handler: handler,
653            log_validator: validator,
654            ingest_interceptor,
655        };
656
657        let router = self.router.nest(
658            &format!("/{HTTP_API_VERSION}"),
659            HttpServer::route_pipelines(log_state.clone()),
660        );
661        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
662        let router = router.nest(
663            &format!("/{HTTP_API_VERSION}/events"),
664            #[allow(deprecated)]
665            HttpServer::route_log_deprecated(log_state.clone()),
666        );
667
668        let router = router.nest(
669            &format!("/{HTTP_API_VERSION}/loki"),
670            HttpServer::route_loki(log_state.clone()),
671        );
672
673        let router = router.nest(
674            &format!("/{HTTP_API_VERSION}/elasticsearch"),
675            HttpServer::route_elasticsearch(log_state.clone()),
676        );
677
678        let router = router.nest(
679            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
680            Router::new()
681                .route("/", routing::get(elasticsearch::handle_get_version))
682                .with_state(log_state.clone()),
683        );
684
685        let router = router.nest(
686            &format!("/{HTTP_API_VERSION}/splunk"),
687            HttpServer::route_splunk(log_state),
688        );
689
690        Self { router, ..self }
691    }
692
693    pub fn with_greptime_config_options(self, opts: String) -> Self {
694        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
695            greptime_config_options: opts,
696        });
697
698        Self {
699            router: self.router.merge(config_router),
700            ..self
701        }
702    }
703
704    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
705        Self {
706            router: self.router.nest(
707                &format!("/{HTTP_API_VERSION}/jaeger"),
708                HttpServer::route_jaeger(handler),
709            ),
710            ..self
711        }
712    }
713
714    pub fn with_dashboard_handler(self, handler: DashboardHandlerRef) -> Self {
715        Self {
716            router: self.router.nest(
717                &format!("/{HTTP_API_VERSION}/dashboards"),
718                HttpServer::route_dashboard(handler),
719            ),
720            ..self
721        }
722    }
723
724    pub fn with_extra_router(self, router: Router) -> Self {
725        Self {
726            router: self.router.merge(router),
727            ..self
728        }
729    }
730
731    pub fn add_layer<L>(self, layer: L) -> Self
732    where
733        L: Layer<Route> + Clone + Send + Sync + 'static,
734        L::Service: Service<Request> + Clone + Send + Sync + 'static,
735        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
736        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
737        <L::Service as Service<Request>>::Future: Send + 'static,
738    {
739        Self {
740            router: self.router.layer(layer),
741            ..self
742        }
743    }
744
745    pub fn build(self) -> HttpServer {
746        HttpServer {
747            options: self.options,
748            user_provider: self.user_provider,
749            shutdown_tx: Mutex::new(None),
750            router: StdMutex::new(self.router),
751            bind_addr: None,
752            memory_limiter: self.memory_limiter,
753        }
754    }
755}
756
757impl HttpServer {
758    /// Gets the router and adds necessary root routes (health, status, dashboard).
759    pub fn make_app(&self) -> Router {
760        let mut router = {
761            let router = self.router.lock().unwrap();
762            router.clone()
763        };
764
765        router = router
766            .route("/", routing::get(handler::index))
767            .route(
768                "/health",
769                routing::get(handler::health).post(handler::health),
770            )
771            .route(
772                &format!("/{HTTP_API_VERSION}/health"),
773                routing::get(handler::health).post(handler::health),
774            )
775            .route(
776                "/ready",
777                routing::get(handler::health).post(handler::health),
778            );
779
780        router = router.route("/status", routing::get(handler::status));
781
782        #[cfg(feature = "dashboard")]
783        {
784            if !self.options.disable_dashboard {
785                info!("Enable dashboard service at '/dashboard'");
786                // redirect /dashboard to /dashboard/
787                router = router.route(
788                    "/dashboard",
789                    routing::get(|uri: axum::http::uri::Uri| async move {
790                        let path = uri.path();
791                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
792
793                        let new_uri = format!("{}/{}", path, query);
794                        axum::response::Redirect::permanent(&new_uri)
795                    }),
796                );
797
798                // "/dashboard" and "/dashboard/" are two different paths in Axum.
799                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
800                // So we explicitly route "/dashboard/" here.
801                router = router
802                    .route(
803                        "/dashboard/",
804                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
805                    )
806                    .route(
807                        "/dashboard/{*x}",
808                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
809                    );
810            }
811        }
812
813        // Add a layer to collect HTTP metrics for axum.
814        router = router.route_layer(middleware::from_fn(http_metrics_layer));
815
816        router
817    }
818
819    /// Attaches middlewares and debug routes to the router.
820    /// Callers should call this method after [HttpServer::make_app()].
821    pub fn build(&self, router: Router) -> Result<Router> {
822        let timeout_layer = if self.options.timeout != Duration::default() {
823            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
824        } else {
825            info!("HTTP server timeout is disabled");
826            None
827        };
828        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
829            Some(
830                ServiceBuilder::new()
831                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
832            )
833        } else {
834            info!("HTTP server body limit is disabled");
835            None
836        };
837        let cors_layer = if self.options.enable_cors {
838            Some(
839                CorsLayer::new()
840                    .allow_methods([
841                        Method::GET,
842                        Method::POST,
843                        Method::PUT,
844                        Method::DELETE,
845                        Method::HEAD,
846                    ])
847                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
848                        AllowOrigin::from(Any)
849                    } else {
850                        AllowOrigin::from(
851                            self.options
852                                .cors_allowed_origins
853                                .iter()
854                                .map(|s| {
855                                    HeaderValue::from_str(s.as_str())
856                                        .context(InvalidHeaderValueSnafu)
857                                })
858                                .collect::<Result<Vec<HeaderValue>>>()?,
859                        )
860                    })
861                    .allow_headers(Any),
862            )
863        } else {
864            info!("HTTP server cross-origin is disabled");
865            None
866        };
867
868        Ok(router
869            // middlewares
870            .layer(
871                ServiceBuilder::new()
872                    // disable on failure tracing. because printing out isn't very helpful,
873                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
874                    .layer(TraceLayer::new_for_http().on_failure(()))
875                    .option_layer(cors_layer)
876                    .option_layer(timeout_layer)
877                    .option_layer(body_limit_layer)
878                    // memory limit layer - must be before body is consumed
879                    .layer(middleware::from_fn_with_state(
880                        self.memory_limiter.clone(),
881                        memory_limit::memory_limit_middleware,
882                    ))
883                    // auth layer
884                    .layer(middleware::from_fn_with_state(
885                        AuthState::new(self.user_provider.clone()),
886                        authorize::check_http_auth,
887                    ))
888                    .layer(middleware::from_fn(hints::extract_hints))
889                    .layer(middleware::from_fn(client_ip::log_error_with_client_ip))
890                    .layer(middleware::from_fn(
891                        read_preference::extract_read_preference,
892                    )),
893            )
894            // Handlers for debug, we don't expect a timeout.
895            .nest(
896                "/debug",
897                Router::new()
898                    // handler for changing log level dynamically
899                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
900                    .route("/enable_trace", routing::post(dyn_trace::dyn_trace_handler))
901                    .nest(
902                        "/prof",
903                        Router::new()
904                            .route("/cpu", routing::post(pprof::pprof_handler))
905                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
906                            .route("/mem/symbol", routing::post(mem_prof::symbolicate_handler))
907                            .route(
908                                "/mem/activate",
909                                routing::post(mem_prof::activate_heap_prof_handler),
910                            )
911                            .route(
912                                "/mem/deactivate",
913                                routing::post(mem_prof::deactivate_heap_prof_handler),
914                            )
915                            .route(
916                                "/mem/status",
917                                routing::get(mem_prof::heap_prof_status_handler),
918                            ) // jemalloc gdump flag status and toggle
919                            .route(
920                                "/mem/gdump",
921                                routing::get(mem_prof::gdump_status_handler)
922                                    .post(mem_prof::gdump_toggle_handler),
923                            ),
924                    ),
925            ))
926    }
927
928    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
929        Router::new()
930            .route("/metrics", routing::get(handler::metrics))
931            .with_state(metrics_handler)
932    }
933
934    fn route_loki<S>(log_state: LogState) -> Router<S> {
935        Router::new()
936            .route("/api/v1/push", routing::post(loki::loki_ingest))
937            .layer(
938                ServiceBuilder::new()
939                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
940            )
941            .with_state(log_state)
942    }
943
944    fn route_splunk<S>(log_state: LogState) -> Router<S> {
945        Router::new()
946            .route(
947                "/services/collector/health",
948                routing::get(splunk::handle_health),
949            )
950            .route(
951                "/services/collector/health/1.0",
952                routing::get(splunk::handle_health),
953            )
954            // The event endpoint plus its base and versioned aliases all serve
955            // the same handler (Splunk JSON event protocol).
956            .route(
957                "/services/collector/event",
958                routing::post(splunk::handle_event),
959            )
960            .route("/services/collector", routing::post(splunk::handle_event))
961            .route(
962                "/services/collector/event/1.0",
963                routing::post(splunk::handle_event),
964            )
965            .layer(
966                ServiceBuilder::new()
967                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
968            )
969            .with_state(log_state)
970    }
971
972    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
973        Router::new()
974            // Return fake responsefor HEAD '/' request.
975            .route(
976                "/",
977                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
978            )
979            // Return fake response for Elasticsearch version request.
980            .route("/", routing::get(elasticsearch::handle_get_version))
981            // Return fake response for Elasticsearch license request.
982            .route("/_license", routing::get(elasticsearch::handle_get_license))
983            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
984            .route(
985                "/{index}/_bulk",
986                routing::post(elasticsearch::handle_bulk_api_with_index),
987            )
988            // Return fake response for Elasticsearch ilm request.
989            .route(
990                "/_ilm/policy/{*path}",
991                routing::any((
992                    HttpStatusCode::OK,
993                    elasticsearch::elasticsearch_headers(),
994                    axum::Json(serde_json::json!({})),
995                )),
996            )
997            // Return fake response for Elasticsearch index template request.
998            .route(
999                "/_index_template/{*path}",
1000                routing::any((
1001                    HttpStatusCode::OK,
1002                    elasticsearch::elasticsearch_headers(),
1003                    axum::Json(serde_json::json!({})),
1004                )),
1005            )
1006            // Return fake response for Elasticsearch ingest pipeline request.
1007            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
1008            .route(
1009                "/_ingest/{*path}",
1010                routing::any((
1011                    HttpStatusCode::OK,
1012                    elasticsearch::elasticsearch_headers(),
1013                    axum::Json(serde_json::json!({})),
1014                )),
1015            )
1016            // Return fake response for Elasticsearch nodes discovery request.
1017            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
1018            .route(
1019                "/_nodes/{*path}",
1020                routing::any((
1021                    HttpStatusCode::OK,
1022                    elasticsearch::elasticsearch_headers(),
1023                    axum::Json(serde_json::json!({})),
1024                )),
1025            )
1026            // Return fake response for Logstash APIs requests.
1027            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
1028            .route(
1029                "/logstash/{*path}",
1030                routing::any((
1031                    HttpStatusCode::OK,
1032                    elasticsearch::elasticsearch_headers(),
1033                    axum::Json(serde_json::json!({})),
1034                )),
1035            )
1036            .route(
1037                "/_logstash/{*path}",
1038                routing::any((
1039                    HttpStatusCode::OK,
1040                    elasticsearch::elasticsearch_headers(),
1041                    axum::Json(serde_json::json!({})),
1042                )),
1043            )
1044            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1045            .with_state(log_state)
1046    }
1047
1048    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1049    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1050        Router::new()
1051            .route("/logs", routing::post(event::log_ingester))
1052            .route(
1053                "/pipelines/{pipeline_name}",
1054                routing::get(event::query_pipeline),
1055            )
1056            .route(
1057                "/pipelines/{pipeline_name}",
1058                routing::post(event::add_pipeline),
1059            )
1060            .route(
1061                "/pipelines/{pipeline_name}",
1062                routing::delete(event::delete_pipeline),
1063            )
1064            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1065            .layer(
1066                ServiceBuilder::new()
1067                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1068            )
1069            .with_state(log_state)
1070    }
1071
1072    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1073        Router::new()
1074            .route("/ingest", routing::post(event::log_ingester))
1075            .route(
1076                "/pipelines/{pipeline_name}",
1077                routing::get(event::query_pipeline),
1078            )
1079            .route(
1080                "/pipelines/{pipeline_name}/ddl",
1081                routing::get(event::query_pipeline_ddl),
1082            )
1083            .route(
1084                "/pipelines/{pipeline_name}",
1085                routing::post(event::add_pipeline),
1086            )
1087            .route(
1088                "/pipelines/{pipeline_name}",
1089                routing::delete(event::delete_pipeline),
1090            )
1091            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1092            .layer(
1093                ServiceBuilder::new()
1094                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1095            )
1096            .with_state(log_state)
1097    }
1098
1099    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1100        Router::new()
1101            .route("/sql", routing::get(handler::sql).post(handler::sql))
1102            .route(
1103                "/sql/parse",
1104                routing::get(handler::sql_parse).post(handler::sql_parse),
1105            )
1106            .route(
1107                "/sql/format",
1108                routing::get(handler::sql_format).post(handler::sql_format),
1109            )
1110            .route(
1111                "/promql",
1112                routing::get(handler::promql).post(handler::promql),
1113            )
1114            .with_state(api_state)
1115    }
1116
1117    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1118        Router::new()
1119            .route("/logs", routing::get(logs::logs).post(logs::logs))
1120            .with_state(log_handler)
1121    }
1122
1123    /// Route Prometheus [HTTP API].
1124    ///
1125    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1126    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1127        Router::new()
1128            .route(
1129                "/format_query",
1130                routing::post(format_query).get(format_query),
1131            )
1132            .route("/status/buildinfo", routing::get(build_info_query))
1133            .route("/query", routing::post(instant_query).get(instant_query))
1134            .route("/query_range", routing::post(range_query).get(range_query))
1135            .route("/labels", routing::post(labels_query).get(labels_query))
1136            .route("/series", routing::post(series_query).get(series_query))
1137            .route("/parse_query", routing::post(parse_query).get(parse_query))
1138            .route(
1139                "/label/{label_name}/values",
1140                routing::get(label_values_query),
1141            )
1142            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1143            .with_state(prometheus_handler)
1144    }
1145
1146    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1147    /// called `prom_store`.
1148    ///
1149    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1150    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1151    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1152        Router::new()
1153            .route("/read", routing::post(prom_store::remote_read))
1154            .route("/write", routing::post(prom_store::remote_write))
1155            .with_state(state)
1156    }
1157
1158    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1159        Router::new()
1160            .route("/write", routing::post(influxdb_write_v1))
1161            .route("/api/v2/write", routing::post(influxdb_write_v2))
1162            .layer(
1163                ServiceBuilder::new()
1164                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1165            )
1166            .route("/ping", routing::get(influxdb_ping))
1167            .route("/health", routing::get(influxdb_health))
1168            .with_state(influxdb_handler)
1169    }
1170
1171    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1172        Router::new()
1173            .route("/api/put", routing::post(opentsdb::put))
1174            .with_state(opentsdb_handler)
1175    }
1176
1177    fn route_otlp<S>(
1178        otlp_handler: OpenTelemetryProtocolHandlerRef,
1179        with_metric_engine: bool,
1180    ) -> Router<S> {
1181        Router::new()
1182            .route("/v1/metrics", routing::post(otlp::metrics))
1183            .route("/v1/traces", routing::post(otlp::traces))
1184            .route("/v1/logs", routing::post(otlp::logs))
1185            .layer(
1186                ServiceBuilder::new()
1187                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1188            )
1189            .with_state(OtlpState {
1190                with_metric_engine,
1191                handler: otlp_handler,
1192            })
1193    }
1194
1195    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1196        Router::new()
1197            .route("/config", routing::get(handler::config))
1198            .with_state(state)
1199    }
1200
1201    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1202        Router::new()
1203            .route("/api/services", routing::get(jaeger::handle_get_services))
1204            .route(
1205                "/api/services/{service_name}/operations",
1206                routing::get(jaeger::handle_get_operations_by_service),
1207            )
1208            .route(
1209                "/api/operations",
1210                routing::get(jaeger::handle_get_operations),
1211            )
1212            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1213            .route(
1214                "/api/traces/{trace_id}",
1215                routing::get(jaeger::handle_get_trace),
1216            )
1217            .with_state(handler)
1218    }
1219
1220    #[cfg(feature = "dashboard")]
1221    fn route_dashboard<S>(handler: DashboardHandlerRef) -> Router<S> {
1222        use crate::http::dashboard::{add_dashboard, delete_dashboard, list_dashboards};
1223
1224        Router::new()
1225            .route("/", routing::get(list_dashboards))
1226            .route("/{dashboard_name}", routing::post(add_dashboard))
1227            .route("/{dashboard_name}", routing::delete(delete_dashboard))
1228            .layer(
1229                ServiceBuilder::new()
1230                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1231            )
1232            .with_state(DashboardState { handler })
1233    }
1234
1235    #[cfg(not(feature = "dashboard"))]
1236    fn route_dashboard<S>(handler: DashboardHandlerRef) -> Router<S> {
1237        Router::new().with_state(DashboardState { handler })
1238    }
1239}
1240
/// Server name reported by `Server::name` for [`HttpServer`].
pub const HTTP_SERVER: &str = "HTTP_SERVER";
1242
1243#[async_trait]
1244impl Server for HttpServer {
1245    async fn shutdown(&self) -> Result<()> {
1246        let mut shutdown_tx = self.shutdown_tx.lock().await;
1247        if let Some(tx) = shutdown_tx.take()
1248            && tx.send(()).is_err()
1249        {
1250            info!("Receiver dropped, the HTTP server has already exited");
1251        }
1252        info!("Shutdown HTTP server");
1253
1254        Ok(())
1255    }
1256
1257    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1258        let (tx, rx) = oneshot::channel();
1259        let serve = {
1260            let mut shutdown_tx = self.shutdown_tx.lock().await;
1261            ensure!(
1262                shutdown_tx.is_none(),
1263                AlreadyStartedSnafu { server: "HTTP" }
1264            );
1265
1266            let app = self.build(self.make_app())?;
1267            let listener = tokio::net::TcpListener::bind(listening)
1268                .await
1269                .context(AddressBindSnafu { addr: listening })?
1270                .tap_io(|tcp_stream| {
1271                    if let Err(e) = tcp_stream.set_nodelay(true) {
1272                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1273                    }
1274                });
1275            let serve = axum::serve(
1276                listener,
1277                app.into_make_service_with_connect_info::<SocketAddr>(),
1278            );
1279
1280            // FIXME(yingwen): Support keepalive.
1281            // See:
1282            // - https://github.com/tokio-rs/axum/discussions/2939
1283            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1284            // let server = axum::Server::try_bind(&listening)
1285            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1286            //     .tcp_nodelay(true)
1287            //     // Enable TCP keepalive to close the dangling established connections.
1288            //     // It's configured to let the keepalive probes first send after the connection sits
1289            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1290            //     // So the connection will be closed after roughly 1 hour.
1291            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1292            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1293            //     .tcp_keepalive_retries(Some(6))
1294            //     .serve(app.into_make_service());
1295
1296            *shutdown_tx = Some(tx);
1297
1298            serve
1299        };
1300        let listening = serve.local_addr().context(InternalIoSnafu)?;
1301        info!("HTTP server is bound to {}", listening);
1302
1303        common_runtime::spawn_global(async move {
1304            if let Err(e) = serve
1305                .with_graceful_shutdown(rx.map(drop))
1306                .await
1307                .context(InternalIoSnafu)
1308            {
1309                error!(e; "Failed to shutdown http server");
1310            }
1311        });
1312
1313        self.bind_addr = Some(listening);
1314        Ok(())
1315    }
1316
1317    fn name(&self) -> &str {
1318        HTTP_SERVER
1319    }
1320
1321    fn bind_addr(&self) -> Option<SocketAddr> {
1322        self.bind_addr
1323    }
1324
1325    fn as_any(&self) -> &dyn std::any::Any {
1326        self
1327    }
1328}
1329
1330#[cfg(test)]
1331mod test {
1332    use std::future::pending;
1333    use std::io::Cursor;
1334    use std::sync::Arc;
1335
1336    use arrow_ipc::reader::StreamReader;
1337    use arrow_schema::DataType;
1338    use axum::handler::Handler;
1339    use axum::http::StatusCode;
1340    use axum::routing::get;
1341    use common_query::Output;
1342    use common_recordbatch::RecordBatches;
1343    use datafusion_expr::LogicalPlan;
1344    use datatypes::prelude::*;
1345    use datatypes::schema::{ColumnSchema, Schema};
1346    use datatypes::vectors::{StringVector, UInt32Vector};
1347    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1348    use query::parser::PromQuery;
1349    use query::query_engine::DescribeResult;
1350    use session::context::QueryContextRef;
1351    use sql::statements::statement::Statement;
1352    use tokio::sync::mpsc;
1353    use tokio::time::Instant;
1354
1355    use super::*;
1356    use crate::http::test_helpers::TestClient;
1357    use crate::prom_remote_write::validation::validate_label_name;
1358    use crate::query_handler::sql::SqlQueryHandler;
1359
1360    struct DummyInstance {
1361        _tx: mpsc::Sender<(String, Vec<u8>)>,
1362    }
1363
1364    #[async_trait]
1365    impl SqlQueryHandler for DummyInstance {
1366        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1367            unimplemented!()
1368        }
1369
1370        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
1371            unimplemented!()
1372        }
1373
1374        async fn do_exec_plan(
1375            &self,
1376            _plan: LogicalPlan,
1377            _stmt: Option<Statement>,
1378            _query_ctx: QueryContextRef,
1379        ) -> Result<Output> {
1380            unimplemented!()
1381        }
1382
1383        async fn do_describe(
1384            &self,
1385            _stmt: sql::statements::statement::Statement,
1386            _query_ctx: QueryContextRef,
1387        ) -> Result<Option<DescribeResult>> {
1388            unimplemented!()
1389        }
1390
1391        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1392            Ok(true)
1393        }
1394    }
1395
1396    fn timeout() -> DynamicTimeoutLayer {
1397        DynamicTimeoutLayer::new(Duration::from_millis(10))
1398    }
1399
1400    async fn forever() {
1401        pending().await
1402    }
1403
1404    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1405        make_test_app_custom(tx, HttpOptions::default())
1406    }
1407
1408    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1409        let instance = Arc::new(DummyInstance { _tx: tx });
1410        let server = HttpServerBuilder::new(options)
1411            .with_sql_handler(instance.clone())
1412            .build();
1413        server.build(server.make_app()).unwrap().route(
1414            "/test/timeout",
1415            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1416        )
1417    }
1418
1419    #[tokio::test]
1420    pub async fn test_cors() {
1421        // cors is on by default
1422        let (tx, _rx) = mpsc::channel(100);
1423        let app = make_test_app(tx);
1424        let client = TestClient::new(app).await;
1425
1426        let res = client.get("/health").send().await;
1427
1428        assert_eq!(res.status(), StatusCode::OK);
1429        assert_eq!(
1430            res.headers()
1431                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1432                .expect("expect cors header origin"),
1433            "*"
1434        );
1435
1436        let res = client.get("/v1/health").send().await;
1437
1438        assert_eq!(res.status(), StatusCode::OK);
1439        assert_eq!(
1440            res.headers()
1441                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1442                .expect("expect cors header origin"),
1443            "*"
1444        );
1445
1446        let res = client
1447            .options("/health")
1448            .header("Access-Control-Request-Headers", "x-greptime-auth")
1449            .header("Access-Control-Request-Method", "DELETE")
1450            .header("Origin", "https://example.com")
1451            .send()
1452            .await;
1453        assert_eq!(res.status(), StatusCode::OK);
1454        assert_eq!(
1455            res.headers()
1456                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1457                .expect("expect cors header origin"),
1458            "*"
1459        );
1460        assert_eq!(
1461            res.headers()
1462                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1463                .expect("expect cors header headers"),
1464            "*"
1465        );
1466        assert_eq!(
1467            res.headers()
1468                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1469                .expect("expect cors header methods"),
1470            "GET,POST,PUT,DELETE,HEAD"
1471        );
1472    }
1473
1474    #[tokio::test]
1475    pub async fn test_cors_custom_origins() {
1476        // cors is on by default
1477        let (tx, _rx) = mpsc::channel(100);
1478        let origin = "https://example.com";
1479
1480        let options = HttpOptions {
1481            cors_allowed_origins: vec![origin.to_string()],
1482            ..Default::default()
1483        };
1484
1485        let app = make_test_app_custom(tx, options);
1486        let client = TestClient::new(app).await;
1487
1488        let res = client.get("/health").header("Origin", origin).send().await;
1489
1490        assert_eq!(res.status(), StatusCode::OK);
1491        assert_eq!(
1492            res.headers()
1493                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1494                .expect("expect cors header origin"),
1495            origin
1496        );
1497
1498        let res = client
1499            .get("/health")
1500            .header("Origin", "https://notallowed.com")
1501            .send()
1502            .await;
1503
1504        assert_eq!(res.status(), StatusCode::OK);
1505        assert!(
1506            !res.headers()
1507                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1508        );
1509    }
1510
1511    #[tokio::test]
1512    pub async fn test_cors_disabled() {
1513        // cors is on by default
1514        let (tx, _rx) = mpsc::channel(100);
1515
1516        let options = HttpOptions {
1517            enable_cors: false,
1518            ..Default::default()
1519        };
1520
1521        let app = make_test_app_custom(tx, options);
1522        let client = TestClient::new(app).await;
1523
1524        let res = client.get("/health").send().await;
1525
1526        assert_eq!(res.status(), StatusCode::OK);
1527        assert!(
1528            !res.headers()
1529                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1530        );
1531    }
1532
1533    #[test]
1534    fn test_http_options_default() {
1535        let default = HttpOptions::default();
1536        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1537        assert_eq!(Duration::from_secs(0), default.timeout)
1538    }
1539
1540    #[tokio::test]
1541    async fn test_http_server_request_timeout() {
1542        common_telemetry::init_default_ut_logging();
1543
1544        let (tx, _rx) = mpsc::channel(100);
1545        let app = make_test_app(tx);
1546        let client = TestClient::new(app).await;
1547        let res = client.get("/test/timeout").send().await;
1548        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1549
1550        let now = Instant::now();
1551        let res = client
1552            .get("/test/timeout")
1553            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1554            .send()
1555            .await;
1556        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1557        let elapsed = now.elapsed();
1558        assert!(elapsed > Duration::from_millis(15));
1559
1560        tokio::time::timeout(
1561            Duration::from_millis(15),
1562            client
1563                .get("/test/timeout")
1564                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1565                .send(),
1566        )
1567        .await
1568        .unwrap_err();
1569
1570        tokio::time::timeout(
1571            Duration::from_millis(15),
1572            client
1573                .get("/test/timeout")
1574                .header(
1575                    GREPTIME_DB_HEADER_TIMEOUT,
1576                    humantime::format_duration(Duration::default()).to_string(),
1577                )
1578                .send(),
1579        )
1580        .await
1581        .unwrap_err();
1582    }
1583
1584    #[tokio::test]
1585    async fn test_schema_for_empty_response() {
1586        let column_schemas = vec![
1587            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1588            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1589        ];
1590        let schema = Arc::new(Schema::new(column_schemas));
1591
1592        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1593        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1594
1595        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1596        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1597            let json_output = &json_resp.output[0];
1598            if let GreptimeQueryOutput::Records(r) = json_output {
1599                assert_eq!(r.num_rows(), 0);
1600                assert_eq!(r.num_cols(), 2);
1601                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1602                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1603            } else {
1604                panic!("invalid output type");
1605            }
1606        } else {
1607            panic!("invalid format")
1608        }
1609    }
1610
1611    #[tokio::test]
1612    async fn test_recordbatches_conversion() {
1613        let column_schemas = vec![
1614            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1615            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1616        ];
1617        let schema = Arc::new(Schema::new(column_schemas));
1618        let columns: Vec<VectorRef> = vec![
1619            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1620            Arc::new(StringVector::from(vec![
1621                None,
1622                Some("hello"),
1623                Some("greptime"),
1624                None,
1625            ])),
1626        ];
1627        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1628
1629        for format in [
1630            ResponseFormat::GreptimedbV1,
1631            ResponseFormat::InfluxdbV1,
1632            ResponseFormat::Csv(true, true),
1633            ResponseFormat::Table,
1634            ResponseFormat::Arrow,
1635            ResponseFormat::Json,
1636            ResponseFormat::Null,
1637        ] {
1638            let recordbatches =
1639                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1640            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1641            let json_resp = match format {
1642                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1643                ResponseFormat::Csv(with_names, with_types) => {
1644                    CsvResponse::from_output(outputs, with_names, with_types).await
1645                }
1646                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1647                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1648                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1649                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1650                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1651            };
1652
1653            match json_resp {
1654                HttpResponse::GreptimedbV1(resp) => {
1655                    let json_output = &resp.output[0];
1656                    if let GreptimeQueryOutput::Records(r) = json_output {
1657                        assert_eq!(r.num_rows(), 4);
1658                        assert_eq!(r.num_cols(), 2);
1659                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1660                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1661                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1662                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1663                    } else {
1664                        panic!("invalid output type");
1665                    }
1666                }
1667                HttpResponse::InfluxdbV1(resp) => {
1668                    let json_output = &resp.results()[0];
1669                    assert_eq!(json_output.num_rows(), 4);
1670                    assert_eq!(json_output.num_cols(), 2);
1671                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1672                    assert_eq!(
1673                        json_output.series[0].values[0][0],
1674                        serde_json::Value::from(1)
1675                    );
1676                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1677                }
1678                HttpResponse::Csv(resp) => {
1679                    let output = &resp.output()[0];
1680                    if let GreptimeQueryOutput::Records(r) = output {
1681                        assert_eq!(r.num_rows(), 4);
1682                        assert_eq!(r.num_cols(), 2);
1683                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1684                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1685                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1686                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1687                    } else {
1688                        panic!("invalid output type");
1689                    }
1690                }
1691
1692                HttpResponse::Table(resp) => {
1693                    let output = &resp.output()[0];
1694                    if let GreptimeQueryOutput::Records(r) = output {
1695                        assert_eq!(r.num_rows(), 4);
1696                        assert_eq!(r.num_cols(), 2);
1697                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1698                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1699                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1700                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1701                    } else {
1702                        panic!("invalid output type");
1703                    }
1704                }
1705
1706                HttpResponse::Arrow(resp) => {
1707                    let output = resp.data;
1708                    let mut reader = StreamReader::try_new(Cursor::new(output), None)
1709                        .expect("Arrow reader error");
1710                    let schema = reader.schema();
1711                    assert_eq!(schema.fields[0].name(), "numbers");
1712                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1713                    assert_eq!(schema.fields[1].name(), "strings");
1714                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1715
1716                    let rb = reader.next().unwrap().expect("read record batch failed");
1717                    assert_eq!(rb.num_columns(), 2);
1718                    assert_eq!(rb.num_rows(), 4);
1719                }
1720
1721                HttpResponse::Json(resp) => {
1722                    let output = &resp.output()[0];
1723                    if let GreptimeQueryOutput::Records(r) = output {
1724                        assert_eq!(r.num_rows(), 4);
1725                        assert_eq!(r.num_cols(), 2);
1726                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1727                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1728                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1729                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1730                    } else {
1731                        panic!("invalid output type");
1732                    }
1733                }
1734
1735                HttpResponse::Null(resp) => {
1736                    assert_eq!(resp.rows(), 4);
1737                }
1738
1739                HttpResponse::Error(err) => unreachable!("{err:?}"),
1740            }
1741        }
1742    }
1743
1744    #[test]
1745    fn test_response_format_misc() {
1746        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
1747        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
1748        assert_eq!(
1749            ResponseFormat::parse("csv"),
1750            Some(ResponseFormat::Csv(false, false))
1751        );
1752        assert_eq!(
1753            ResponseFormat::parse("csvwithnames"),
1754            Some(ResponseFormat::Csv(true, false))
1755        );
1756        assert_eq!(
1757            ResponseFormat::parse("csvwithnamesandtypes"),
1758            Some(ResponseFormat::Csv(true, true))
1759        );
1760        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
1761        assert_eq!(
1762            ResponseFormat::parse("greptimedb_v1"),
1763            Some(ResponseFormat::GreptimedbV1)
1764        );
1765        assert_eq!(
1766            ResponseFormat::parse("influxdb_v1"),
1767            Some(ResponseFormat::InfluxdbV1)
1768        );
1769        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
1770        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));
1771
1772        // invalid formats
1773        assert_eq!(ResponseFormat::parse("invalid"), None);
1774        assert_eq!(ResponseFormat::parse(""), None);
1775        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive
1776
1777        // as str
1778        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
1779        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
1780        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
1781        assert_eq!(ResponseFormat::Table.as_str(), "table");
1782        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
1783        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
1784        assert_eq!(ResponseFormat::Json.as_str(), "json");
1785        assert_eq!(ResponseFormat::Null.as_str(), "null");
1786        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
1787    }
1788
1789    #[test]
1790    fn test_decode_label_name_strict() {
1791        let strict = PromValidationMode::Strict;
1792
1793        // Valid Prometheus label names
1794        assert!(strict.decode_label_name(b"__name__").is_ok());
1795        assert!(strict.decode_label_name(b"job").is_ok());
1796        assert!(strict.decode_label_name(b"instance").is_ok());
1797        assert!(strict.decode_label_name(b"_private").is_ok());
1798        assert!(strict.decode_label_name(b"label_with_underscores").is_ok());
1799        assert!(strict.decode_label_name(b"abc123").is_ok());
1800        assert!(strict.decode_label_name(b"A").is_ok());
1801        assert!(strict.decode_label_name(b"_").is_ok());
1802
1803        // Invalid: starts with digit
1804        assert!(strict.decode_label_name(b"0abc").is_err());
1805        assert!(strict.decode_label_name(b"123").is_err());
1806
1807        // Invalid: contains special characters
1808        assert!(strict.decode_label_name(b"label-name").is_err());
1809        assert!(strict.decode_label_name(b"label.name").is_err());
1810        assert!(strict.decode_label_name(b"label name").is_err());
1811        assert!(strict.decode_label_name(b"label/name").is_err());
1812
1813        // Invalid: empty
1814        assert!(strict.decode_label_name(b"").is_err());
1815
1816        // Invalid: non-ASCII UTF-8
1817        assert!(strict.decode_label_name("ラベル".as_bytes()).is_err());
1818
1819        // Invalid UTF-8 bytes should fail
1820        assert!(strict.decode_label_name(&[0xff, 0xfe]).is_err());
1821    }
1822
1823    #[test]
1824    fn test_decode_label_name_lossy() {
1825        let lossy = PromValidationMode::Lossy;
1826
1827        // Label name validation is always enforced.
1828        assert!(lossy.decode_label_name(b"__name__").is_ok());
1829        assert!(lossy.decode_label_name(b"label-name").is_err());
1830        assert!(lossy.decode_label_name(b"0abc").is_err());
1831
1832        // Invalid UTF-8 bytes fail the label-name byte check.
1833        assert!(lossy.decode_label_name(&[0xff, 0xfe]).is_err());
1834    }
1835
1836    #[test]
1837    fn test_decode_label_name_unchecked() {
1838        let unchecked = PromValidationMode::Unchecked;
1839
1840        // Label name validation is always enforced.
1841        assert!(unchecked.decode_label_name(b"__name__").is_ok());
1842        assert!(unchecked.decode_label_name(b"label-name").is_err());
1843        assert!(unchecked.decode_label_name(b"0abc").is_err());
1844    }
1845
1846    #[test]
1847    fn test_is_valid_prom_label_name_bytes() {
1848        assert!(validate_label_name(b"__name__"));
1849        assert!(validate_label_name(b"job"));
1850        assert!(validate_label_name(b"_"));
1851        assert!(validate_label_name(b"A"));
1852        assert!(validate_label_name(b"abc123"));
1853        assert!(validate_label_name(b"_leading_underscore"));
1854
1855        assert!(!validate_label_name(b""));
1856        assert!(!validate_label_name(b"0starts_with_digit"));
1857        assert!(!validate_label_name(b"has-dash"));
1858        assert!(!validate_label_name(b"has.dot"));
1859        assert!(!validate_label_name(b"has space"));
1860        assert!(!validate_label_name(&[0xff, 0xfe]));
1861    }
1862}