servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::fmt::Display;
18use std::net::SocketAddr;
19use std::sync::Mutex as StdMutex;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use auth::UserProviderRef;
24use axum::extract::{DefaultBodyLimit, Request};
25use axum::http::StatusCode as HttpStatusCode;
26use axum::response::{IntoResponse, Response};
27use axum::routing::Route;
28use axum::serve::ListenerExt;
29use axum::{Router, middleware, routing};
30use common_base::Plugins;
31use common_base::readable_size::ReadableSize;
32use common_recordbatch::RecordBatch;
33use common_telemetry::{debug, error, info};
34use common_time::Timestamp;
35use common_time::timestamp::TimeUnit;
36use datatypes::data_type::DataType;
37use datatypes::schema::SchemaRef;
38use event::{LogState, LogValidatorRef};
39use futures::FutureExt;
40use http::{HeaderValue, Method};
41use prost::DecodeError;
42use serde::{Deserialize, Serialize};
43use serde_json::Value;
44use snafu::{ResultExt, ensure};
45use tokio::sync::Mutex;
46use tokio::sync::oneshot::{self, Sender};
47use tonic::codegen::Service;
48use tower::{Layer, ServiceBuilder};
49use tower_http::compression::CompressionLayer;
50use tower_http::cors::{AllowOrigin, Any, CorsLayer};
51use tower_http::decompression::RequestDecompressionLayer;
52use tower_http::trace::TraceLayer;
53
54use self::authorize::AuthState;
55use self::result::table_result::TableResponse;
56use crate::configurator::HttpConfiguratorRef;
57use crate::elasticsearch;
58use crate::error::{
59    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu,
60    OtherSnafu, Result,
61};
62use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
63use crate::http::otlp::OtlpState;
64use crate::http::prom_store::PromStoreState;
65use crate::http::prometheus::{
66    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
67    range_query, series_query,
68};
69use crate::http::result::arrow_result::ArrowResponse;
70use crate::http::result::csv_result::CsvResponse;
71use crate::http::result::error_result::ErrorResponse;
72use crate::http::result::greptime_result_v1::GreptimedbV1Response;
73use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
74use crate::http::result::json_result::JsonResponse;
75use crate::http::result::null_result::NullResponse;
76use crate::interceptor::LogIngestInterceptorRef;
77use crate::metrics::http_metrics_layer;
78use crate::metrics_handler::MetricsHandler;
79use crate::prometheus_handler::PrometheusHandlerRef;
80use crate::query_handler::sql::ServerSqlQueryHandlerRef;
81use crate::query_handler::{
82    InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
83    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
84    PromStoreProtocolHandlerRef,
85};
86use crate::request_limiter::RequestMemoryLimiter;
87use crate::server::Server;
88
89pub mod authorize;
90#[cfg(feature = "dashboard")]
91mod dashboard;
92pub mod dyn_log;
93pub mod dyn_trace;
94pub mod event;
95pub mod extractor;
96pub mod handler;
97pub mod header;
98pub mod influxdb;
99pub mod jaeger;
100pub mod logs;
101pub mod loki;
102pub mod mem_prof;
103mod memory_limit;
104pub mod opentsdb;
105pub mod otlp;
106pub mod pprof;
107pub mod prom_store;
108pub mod prometheus;
109pub mod result;
110mod timeout;
111pub mod utils;
112
113use result::HttpOutputWriter;
114pub(crate) use timeout::DynamicTimeoutLayer;
115
116mod hints;
117mod read_preference;
118#[cfg(any(test, feature = "testing"))]
119pub mod test_helpers;
120
/// API version segment; versioned routes are nested under `/{HTTP_API_VERSION}`.
pub const HTTP_API_VERSION: &str = "v1";
/// Versioned path prefix, i.e. `/{HTTP_API_VERSION}/`.
/// NOTE(review): spelled out by hand, so it must be updated together with `HTTP_API_VERSION`.
pub const HTTP_API_PREFIX: &str = "/v1/";
/// Default http body limit (64M); used as the default of `HttpOptions::body_limit`.
const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);

/// Authorization header
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";

// TODO(fys): This is a temporary workaround, it will be improved later
/// Request paths treated as public APIs.
/// NOTE(review): presumably these bypass authentication in the `authorize` module — confirm
/// there. The `/health`, `/ready` and `/status` routes added by `HttpServer::make_app` are not
/// listed, and every entry hard-codes the `v1` version.
pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
131
/// HTTP server of GreptimeDB, usually created with [HttpServerBuilder::build].
#[derive(Default)]
pub struct HttpServer {
    // Routes registered through the builder; `make_app` clones it while holding this lock.
    router: StdMutex<Router>,
    // NOTE(review): presumably the sender used to signal shutdown once the server is started;
    // the start/shutdown code is outside this view.
    shutdown_tx: Mutex<Option<Sender<()>>>,
    // Handed to the authentication middleware (`AuthState`) in `build`.
    user_provider: Option<UserProviderRef>,
    // Caps the memory of all concurrent request bodies; built from
    // `HttpOptions::max_total_body_memory`.
    memory_limiter: RequestMemoryLimiter,

    // plugins
    plugins: Plugins,

    // server configs
    options: HttpOptions,
    // `None` when built by the builder; presumably filled in once the server binds.
    bind_addr: Option<SocketAddr>,
}
146
/// Configuration of the HTTP server.
///
/// Fields missing from the deserialized input take their value from [HttpOptions::default]
/// (`#[serde(default)]`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct HttpOptions {
    /// Address the server listens on; `127.0.0.1:4000` by default.
    pub addr: String,

    /// Request timeout, (de)serialized in humantime format (e.g. `30s`). Zero disables it.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,

    /// Whether to skip serving the dashboard. Skipped by serde, so it is never read from or
    /// written to configuration files.
    #[serde(skip)]
    pub disable_dashboard: bool,

    /// Per-request body size limit. Zero disables it.
    pub body_limit: ReadableSize,

    /// Maximum total memory for all concurrent HTTP request bodies. 0 disables the limit.
    pub max_total_body_memory: ReadableSize,

    /// Validation mode while decoding Prometheus remote write requests.
    pub prom_validation_mode: PromValidationMode,

    /// Origins allowed by the CORS layer. An empty list allows any origin.
    /// Only used when `enable_cors` is true.
    pub cors_allowed_origins: Vec<String>,

    /// Whether to attach the CORS layer at all.
    pub enable_cors: bool,
}
170
/// UTF-8 validation policy for strings decoded from Prometheus remote write requests.
///
/// (De)serialized in snake_case: `strict`, `lossy` or `unchecked`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PromValidationMode {
    /// Force UTF8 validation: invalid strings are rejected.
    Strict,
    /// Allow lossy UTF8 strings: invalid sequences are replaced with U+FFFD.
    Lossy,
    /// Do not validate UTF8 strings; the input is trusted as-is.
    Unchecked,
}
181
182impl PromValidationMode {
183    /// Decodes provided bytes to [String] with optional UTF-8 validation.
184    pub fn decode_string(&self, bytes: &[u8]) -> std::result::Result<String, DecodeError> {
185        let result = match self {
186            PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
187                Ok(s) => s,
188                Err(e) => {
189                    debug!("Invalid UTF-8 string value: {:?}, error: {:?}", bytes, e);
190                    return Err(DecodeError::new("invalid utf-8"));
191                }
192            },
193            PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
194            PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
195        };
196        Ok(result)
197    }
198}
199
200impl Default for HttpOptions {
201    fn default() -> Self {
202        Self {
203            addr: "127.0.0.1:4000".to_string(),
204            timeout: Duration::from_secs(0),
205            disable_dashboard: false,
206            body_limit: DEFAULT_BODY_LIMIT,
207            max_total_body_memory: ReadableSize(0),
208            cors_allowed_origins: Vec::new(),
209            enable_cors: true,
210            prom_validation_mode: PromValidationMode::Strict,
211        }
212    }
213}
214
215#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
216pub struct ColumnSchema {
217    name: String,
218    data_type: String,
219}
220
221impl ColumnSchema {
222    pub fn new(name: String, data_type: String) -> ColumnSchema {
223        ColumnSchema { name, data_type }
224    }
225}
226
227#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
228pub struct OutputSchema {
229    column_schemas: Vec<ColumnSchema>,
230}
231
232impl OutputSchema {
233    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
234        OutputSchema {
235            column_schemas: columns,
236        }
237    }
238}
239
240impl From<SchemaRef> for OutputSchema {
241    fn from(schema: SchemaRef) -> OutputSchema {
242        OutputSchema {
243            column_schemas: schema
244                .column_schemas()
245                .iter()
246                .map(|cs| ColumnSchema {
247                    name: cs.name.clone(),
248                    data_type: cs.data_type.name(),
249                })
250                .collect(),
251        }
252    }
253}
254
/// Tabular query result: a schema plus rows of JSON values.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct HttpRecordsOutput {
    // Column names and data type names describing `rows`.
    schema: OutputSchema,
    // One inner vector per row.
    rows: Vec<Vec<Value>>,
    // total_rows is equal to rows.len() in most cases,
    // the Dashboard query result may be truncated, so we need to return the total_rows.
    // Defaults to 0 when absent from deserialized input.
    #[serde(default)]
    total_rows: usize,

    // plan level execution metrics
    // Omitted from the serialized output when empty.
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    #[serde(default)]
    metrics: HashMap<String, Value>,
}
269
270impl HttpRecordsOutput {
271    pub fn num_rows(&self) -> usize {
272        self.rows.len()
273    }
274
275    pub fn num_cols(&self) -> usize {
276        self.schema.column_schemas.len()
277    }
278
279    pub fn schema(&self) -> &OutputSchema {
280        &self.schema
281    }
282
283    pub fn rows(&self) -> &Vec<Vec<Value>> {
284        &self.rows
285    }
286}
287
288impl HttpRecordsOutput {
289    pub fn try_new(
290        schema: SchemaRef,
291        recordbatches: Vec<RecordBatch>,
292    ) -> std::result::Result<HttpRecordsOutput, Error> {
293        if recordbatches.is_empty() {
294            Ok(HttpRecordsOutput {
295                schema: OutputSchema::from(schema),
296                rows: vec![],
297                total_rows: 0,
298                metrics: Default::default(),
299            })
300        } else {
301            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
302            let mut rows = Vec::with_capacity(num_rows);
303
304            for recordbatch in recordbatches {
305                let mut writer = HttpOutputWriter::new(schema.num_columns(), None);
306                writer.write(recordbatch, &mut rows)?;
307            }
308
309            Ok(HttpRecordsOutput {
310                schema: OutputSchema::from(schema),
311                total_rows: rows.len(),
312                rows,
313                metrics: Default::default(),
314            })
315        }
316    }
317}
318
/// Output of a single statement: either an affected-row count or a set of records.
///
/// Variant names are serialized in lowercase, e.g. `{"affectedrows": 1}` or
/// `{"records": {...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GreptimeQueryOutput {
    /// Number of affected rows.
    AffectedRows(usize),
    /// Rows returned by the statement.
    Records(HttpRecordsOutput),
}
325
/// It allows the results of SQL queries to be presented in different formats.
///
/// The default format is [ResponseFormat::GreptimedbV1].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    // (with_names, with_types)
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Parses a format name (case-sensitive); unknown names yield `None`.
    ///
    /// CSV is spelled `csv`, `csvwithnames` or `csvwithnamesandtypes`, which set the
    /// `(with_names, with_types)` flags respectively to `(false, false)`, `(true, false)` and
    /// `(true, true)`.
    pub fn parse(s: &str) -> Option<Self> {
        let format = match s {
            "arrow" => Self::Arrow,
            "csv" => Self::Csv(false, false),
            "csvwithnames" => Self::Csv(true, false),
            "csvwithnamesandtypes" => Self::Csv(true, true),
            "table" => Self::Table,
            "greptimedb_v1" => Self::GreptimedbV1,
            "influxdb_v1" => Self::InfluxdbV1,
            "json" => Self::Json,
            "null" => Self::Null,
            _ => return None,
        };
        Some(format)
    }

    /// Returns the short name of the format.
    ///
    /// Every CSV flavour maps to `"csv"`, so `parse(as_str())` does not round-trip
    /// `csvwithnames` / `csvwithnamesandtypes`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
368
/// Timestamp precision requested through the InfluxDB v1 `epoch` query parameter
/// (see [Epoch::parse]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Epoch {
    Nanosecond,
    Microsecond,
    Millisecond,
    Second,
}
376
377impl Epoch {
378    pub fn parse(s: &str) -> Option<Epoch> {
379        // Both u and µ indicate microseconds.
380        // epoch = [ns,u,µ,ms,s],
381        // For details, see the Influxdb documents.
382        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
383        match s {
384            "ns" => Some(Epoch::Nanosecond),
385            "u" | "µ" => Some(Epoch::Microsecond),
386            "ms" => Some(Epoch::Millisecond),
387            "s" => Some(Epoch::Second),
388            _ => None, // just returns None for other cases
389        }
390    }
391
392    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
393        match self {
394            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
395            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
396            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
397            Epoch::Second => ts.convert_to(TimeUnit::Second),
398        }
399    }
400}
401
402impl Display for Epoch {
403    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404        match self {
405            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
406            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
407            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
408            Epoch::Second => write!(f, "Epoch::Second"),
409        }
410    }
411}
412
/// HTTP response in one of the supported output formats, or an error response.
///
/// Each variant wraps a format-specific response type; the [IntoResponse] implementation
/// delegates to the wrapped value.
#[derive(Serialize, Deserialize, Debug)]
pub enum HttpResponse {
    Arrow(ArrowResponse),
    Csv(CsvResponse),
    Table(TableResponse),
    Error(ErrorResponse),
    GreptimedbV1(GreptimedbV1Response),
    InfluxdbV1(InfluxdbV1Response),
    Json(JsonResponse),
    Null(NullResponse),
}
424
425impl HttpResponse {
426    pub fn with_execution_time(self, execution_time: u64) -> Self {
427        match self {
428            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
429            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
430            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
431            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
432            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
433            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
434            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
435            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
436        }
437    }
438
439    pub fn with_limit(self, limit: usize) -> Self {
440        match self {
441            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
442            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
443            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
444            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
445            _ => self,
446        }
447    }
448}
449
450pub fn process_with_limit(
451    mut outputs: Vec<GreptimeQueryOutput>,
452    limit: usize,
453) -> Vec<GreptimeQueryOutput> {
454    outputs
455        .drain(..)
456        .map(|data| match data {
457            GreptimeQueryOutput::Records(mut records) => {
458                if records.rows.len() > limit {
459                    records.rows.truncate(limit);
460                    records.total_rows = limit;
461                }
462                GreptimeQueryOutput::Records(records)
463            }
464            _ => data,
465        })
466        .collect()
467}
468
469impl IntoResponse for HttpResponse {
470    fn into_response(self) -> Response {
471        match self {
472            HttpResponse::Arrow(resp) => resp.into_response(),
473            HttpResponse::Csv(resp) => resp.into_response(),
474            HttpResponse::Table(resp) => resp.into_response(),
475            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
476            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
477            HttpResponse::Json(resp) => resp.into_response(),
478            HttpResponse::Null(resp) => resp.into_response(),
479            HttpResponse::Error(resp) => resp.into_response(),
480        }
481    }
482}
483
484impl From<ArrowResponse> for HttpResponse {
485    fn from(value: ArrowResponse) -> Self {
486        HttpResponse::Arrow(value)
487    }
488}
489
490impl From<CsvResponse> for HttpResponse {
491    fn from(value: CsvResponse) -> Self {
492        HttpResponse::Csv(value)
493    }
494}
495
496impl From<TableResponse> for HttpResponse {
497    fn from(value: TableResponse) -> Self {
498        HttpResponse::Table(value)
499    }
500}
501
502impl From<ErrorResponse> for HttpResponse {
503    fn from(value: ErrorResponse) -> Self {
504        HttpResponse::Error(value)
505    }
506}
507
508impl From<GreptimedbV1Response> for HttpResponse {
509    fn from(value: GreptimedbV1Response) -> Self {
510        HttpResponse::GreptimedbV1(value)
511    }
512}
513
514impl From<InfluxdbV1Response> for HttpResponse {
515    fn from(value: InfluxdbV1Response) -> Self {
516        HttpResponse::InfluxdbV1(value)
517    }
518}
519
520impl From<JsonResponse> for HttpResponse {
521    fn from(value: JsonResponse) -> Self {
522        HttpResponse::Json(value)
523    }
524}
525
526impl From<NullResponse> for HttpResponse {
527    fn from(value: NullResponse) -> Self {
528        HttpResponse::Null(value)
529    }
530}
531
/// State of the SQL routes registered by [HttpServerBuilder::with_sql_handler].
#[derive(Clone)]
pub struct ApiState {
    pub sql_handler: ServerSqlQueryHandlerRef,
}
536
/// State of the config routes registered by
/// [HttpServerBuilder::with_greptime_config_options]; holds the options string passed there.
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    pub greptime_config_options: String,
}
541
/// Builder for [HttpServer]: accumulates routes, plugins and the user provider.
#[derive(Default)]
pub struct HttpServerBuilder {
    // Server configuration moved into the built server.
    options: HttpOptions,
    plugins: Plugins,
    // When set, handed to the authentication middleware of the built server.
    user_provider: Option<UserProviderRef>,
    // Routes accumulated by the `with_*` methods.
    router: Router,
}
549
550impl HttpServerBuilder {
551    pub fn new(options: HttpOptions) -> Self {
552        Self {
553            options,
554            plugins: Plugins::default(),
555            user_provider: None,
556            router: Router::new(),
557        }
558    }
559
560    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
561        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
562
563        Self {
564            router: self
565                .router
566                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
567            ..self
568        }
569    }
570
571    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
572        let logs_router = HttpServer::route_logs(logs_handler);
573
574        Self {
575            router: self
576                .router
577                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
578            ..self
579        }
580    }
581
582    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
583        Self {
584            router: self.router.nest(
585                &format!("/{HTTP_API_VERSION}/opentsdb"),
586                HttpServer::route_opentsdb(handler),
587            ),
588            ..self
589        }
590    }
591
592    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
593        Self {
594            router: self.router.nest(
595                &format!("/{HTTP_API_VERSION}/influxdb"),
596                HttpServer::route_influxdb(handler),
597            ),
598            ..self
599        }
600    }
601
602    pub fn with_prom_handler(
603        self,
604        handler: PromStoreProtocolHandlerRef,
605        pipeline_handler: Option<PipelineHandlerRef>,
606        prom_store_with_metric_engine: bool,
607        prom_validation_mode: PromValidationMode,
608    ) -> Self {
609        let state = PromStoreState {
610            prom_store_handler: handler,
611            pipeline_handler,
612            prom_store_with_metric_engine,
613            prom_validation_mode,
614        };
615
616        Self {
617            router: self.router.nest(
618                &format!("/{HTTP_API_VERSION}/prometheus"),
619                HttpServer::route_prom(state),
620            ),
621            ..self
622        }
623    }
624
625    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
626        Self {
627            router: self.router.nest(
628                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
629                HttpServer::route_prometheus(handler),
630            ),
631            ..self
632        }
633    }
634
635    pub fn with_otlp_handler(
636        self,
637        handler: OpenTelemetryProtocolHandlerRef,
638        with_metric_engine: bool,
639    ) -> Self {
640        Self {
641            router: self.router.nest(
642                &format!("/{HTTP_API_VERSION}/otlp"),
643                HttpServer::route_otlp(handler, with_metric_engine),
644            ),
645            ..self
646        }
647    }
648
649    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
650        Self {
651            user_provider: Some(user_provider),
652            ..self
653        }
654    }
655
656    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
657        Self {
658            router: self.router.merge(HttpServer::route_metrics(handler)),
659            ..self
660        }
661    }
662
663    pub fn with_log_ingest_handler(
664        self,
665        handler: PipelineHandlerRef,
666        validator: Option<LogValidatorRef>,
667        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
668    ) -> Self {
669        let log_state = LogState {
670            log_handler: handler,
671            log_validator: validator,
672            ingest_interceptor,
673        };
674
675        let router = self.router.nest(
676            &format!("/{HTTP_API_VERSION}"),
677            HttpServer::route_pipelines(log_state.clone()),
678        );
679        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
680        let router = router.nest(
681            &format!("/{HTTP_API_VERSION}/events"),
682            #[allow(deprecated)]
683            HttpServer::route_log_deprecated(log_state.clone()),
684        );
685
686        let router = router.nest(
687            &format!("/{HTTP_API_VERSION}/loki"),
688            HttpServer::route_loki(log_state.clone()),
689        );
690
691        let router = router.nest(
692            &format!("/{HTTP_API_VERSION}/elasticsearch"),
693            HttpServer::route_elasticsearch(log_state.clone()),
694        );
695
696        let router = router.nest(
697            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
698            Router::new()
699                .route("/", routing::get(elasticsearch::handle_get_version))
700                .with_state(log_state),
701        );
702
703        Self { router, ..self }
704    }
705
706    pub fn with_plugins(self, plugins: Plugins) -> Self {
707        Self { plugins, ..self }
708    }
709
710    pub fn with_greptime_config_options(self, opts: String) -> Self {
711        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
712            greptime_config_options: opts,
713        });
714
715        Self {
716            router: self.router.merge(config_router),
717            ..self
718        }
719    }
720
721    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
722        Self {
723            router: self.router.nest(
724                &format!("/{HTTP_API_VERSION}/jaeger"),
725                HttpServer::route_jaeger(handler),
726            ),
727            ..self
728        }
729    }
730
731    pub fn with_extra_router(self, router: Router) -> Self {
732        Self {
733            router: self.router.merge(router),
734            ..self
735        }
736    }
737
738    pub fn add_layer<L>(self, layer: L) -> Self
739    where
740        L: Layer<Route> + Clone + Send + Sync + 'static,
741        L::Service: Service<Request> + Clone + Send + Sync + 'static,
742        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
743        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
744        <L::Service as Service<Request>>::Future: Send + 'static,
745    {
746        Self {
747            router: self.router.layer(layer),
748            ..self
749        }
750    }
751
752    pub fn build(self) -> HttpServer {
753        let memory_limiter =
754            RequestMemoryLimiter::new(self.options.max_total_body_memory.as_bytes() as usize);
755        HttpServer {
756            options: self.options,
757            user_provider: self.user_provider,
758            shutdown_tx: Mutex::new(None),
759            plugins: self.plugins,
760            router: StdMutex::new(self.router),
761            bind_addr: None,
762            memory_limiter,
763        }
764    }
765}
766
767impl HttpServer {
768    /// Gets the router and adds necessary root routes (health, status, dashboard).
769    pub fn make_app(&self) -> Router {
770        let mut router = {
771            let router = self.router.lock().unwrap();
772            router.clone()
773        };
774
775        router = router
776            .route(
777                "/health",
778                routing::get(handler::health).post(handler::health),
779            )
780            .route(
781                &format!("/{HTTP_API_VERSION}/health"),
782                routing::get(handler::health).post(handler::health),
783            )
784            .route(
785                "/ready",
786                routing::get(handler::health).post(handler::health),
787            );
788
789        router = router.route("/status", routing::get(handler::status));
790
791        #[cfg(feature = "dashboard")]
792        {
793            if !self.options.disable_dashboard {
794                info!("Enable dashboard service at '/dashboard'");
795                // redirect /dashboard to /dashboard/
796                router = router.route(
797                    "/dashboard",
798                    routing::get(|uri: axum::http::uri::Uri| async move {
799                        let path = uri.path();
800                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
801
802                        let new_uri = format!("{}/{}", path, query);
803                        axum::response::Redirect::permanent(&new_uri)
804                    }),
805                );
806
807                // "/dashboard" and "/dashboard/" are two different paths in Axum.
808                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
809                // So we explicitly route "/dashboard/" here.
810                router = router
811                    .route(
812                        "/dashboard/",
813                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
814                    )
815                    .route(
816                        "/dashboard/{*x}",
817                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
818                    );
819            }
820        }
821
822        // Add a layer to collect HTTP metrics for axum.
823        router = router.route_layer(middleware::from_fn(http_metrics_layer));
824
825        router
826    }
827
828    /// Attaches middlewares and debug routes to the router.
829    /// Callers should call this method after [HttpServer::make_app()].
830    pub fn build(&self, router: Router) -> Result<Router> {
831        let timeout_layer = if self.options.timeout != Duration::default() {
832            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
833        } else {
834            info!("HTTP server timeout is disabled");
835            None
836        };
837        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
838            Some(
839                ServiceBuilder::new()
840                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
841            )
842        } else {
843            info!("HTTP server body limit is disabled");
844            None
845        };
846        let cors_layer = if self.options.enable_cors {
847            Some(
848                CorsLayer::new()
849                    .allow_methods([
850                        Method::GET,
851                        Method::POST,
852                        Method::PUT,
853                        Method::DELETE,
854                        Method::HEAD,
855                    ])
856                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
857                        AllowOrigin::from(Any)
858                    } else {
859                        AllowOrigin::from(
860                            self.options
861                                .cors_allowed_origins
862                                .iter()
863                                .map(|s| {
864                                    HeaderValue::from_str(s.as_str())
865                                        .context(InvalidHeaderValueSnafu)
866                                })
867                                .collect::<Result<Vec<HeaderValue>>>()?,
868                        )
869                    })
870                    .allow_headers(Any),
871            )
872        } else {
873            info!("HTTP server cross-origin is disabled");
874            None
875        };
876
877        Ok(router
878            // middlewares
879            .layer(
880                ServiceBuilder::new()
881                    // disable on failure tracing. because printing out isn't very helpful,
882                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
883                    .layer(TraceLayer::new_for_http().on_failure(()))
884                    .option_layer(cors_layer)
885                    .option_layer(timeout_layer)
886                    .option_layer(body_limit_layer)
887                    // memory limit layer - must be before body is consumed
888                    .layer(middleware::from_fn_with_state(
889                        self.memory_limiter.clone(),
890                        memory_limit::memory_limit_middleware,
891                    ))
892                    // auth layer
893                    .layer(middleware::from_fn_with_state(
894                        AuthState::new(self.user_provider.clone()),
895                        authorize::check_http_auth,
896                    ))
897                    .layer(middleware::from_fn(hints::extract_hints))
898                    .layer(middleware::from_fn(
899                        read_preference::extract_read_preference,
900                    )),
901            )
902            // Handlers for debug, we don't expect a timeout.
903            .nest(
904                "/debug",
905                Router::new()
906                    // handler for changing log level dynamically
907                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
908                    .route("/enable_trace", routing::post(dyn_trace::dyn_trace_handler))
909                    .nest(
910                        "/prof",
911                        Router::new()
912                            .route("/cpu", routing::post(pprof::pprof_handler))
913                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
914                            .route(
915                                "/mem/activate",
916                                routing::post(mem_prof::activate_heap_prof_handler),
917                            )
918                            .route(
919                                "/mem/deactivate",
920                                routing::post(mem_prof::deactivate_heap_prof_handler),
921                            )
922                            .route(
923                                "/mem/status",
924                                routing::get(mem_prof::heap_prof_status_handler),
925                            ) // jemalloc gdump flag status and toggle
926                            .route(
927                                "/mem/gdump",
928                                routing::get(mem_prof::gdump_status_handler)
929                                    .post(mem_prof::gdump_toggle_handler),
930                            ),
931                    ),
932            ))
933    }
934
935    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
936        Router::new()
937            .route("/metrics", routing::get(handler::metrics))
938            .with_state(metrics_handler)
939    }
940
941    fn route_loki<S>(log_state: LogState) -> Router<S> {
942        Router::new()
943            .route("/api/v1/push", routing::post(loki::loki_ingest))
944            .layer(
945                ServiceBuilder::new()
946                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
947            )
948            .with_state(log_state)
949    }
950
951    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
952        Router::new()
953            // Return fake responsefor HEAD '/' request.
954            .route(
955                "/",
956                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
957            )
958            // Return fake response for Elasticsearch version request.
959            .route("/", routing::get(elasticsearch::handle_get_version))
960            // Return fake response for Elasticsearch license request.
961            .route("/_license", routing::get(elasticsearch::handle_get_license))
962            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
963            .route(
964                "/{index}/_bulk",
965                routing::post(elasticsearch::handle_bulk_api_with_index),
966            )
967            // Return fake response for Elasticsearch ilm request.
968            .route(
969                "/_ilm/policy/{*path}",
970                routing::any((
971                    HttpStatusCode::OK,
972                    elasticsearch::elasticsearch_headers(),
973                    axum::Json(serde_json::json!({})),
974                )),
975            )
976            // Return fake response for Elasticsearch index template request.
977            .route(
978                "/_index_template/{*path}",
979                routing::any((
980                    HttpStatusCode::OK,
981                    elasticsearch::elasticsearch_headers(),
982                    axum::Json(serde_json::json!({})),
983                )),
984            )
985            // Return fake response for Elasticsearch ingest pipeline request.
986            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
987            .route(
988                "/_ingest/{*path}",
989                routing::any((
990                    HttpStatusCode::OK,
991                    elasticsearch::elasticsearch_headers(),
992                    axum::Json(serde_json::json!({})),
993                )),
994            )
995            // Return fake response for Elasticsearch nodes discovery request.
996            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
997            .route(
998                "/_nodes/{*path}",
999                routing::any((
1000                    HttpStatusCode::OK,
1001                    elasticsearch::elasticsearch_headers(),
1002                    axum::Json(serde_json::json!({})),
1003                )),
1004            )
1005            // Return fake response for Logstash APIs requests.
1006            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
1007            .route(
1008                "/logstash/{*path}",
1009                routing::any((
1010                    HttpStatusCode::OK,
1011                    elasticsearch::elasticsearch_headers(),
1012                    axum::Json(serde_json::json!({})),
1013                )),
1014            )
1015            .route(
1016                "/_logstash/{*path}",
1017                routing::any((
1018                    HttpStatusCode::OK,
1019                    elasticsearch::elasticsearch_headers(),
1020                    axum::Json(serde_json::json!({})),
1021                )),
1022            )
1023            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1024            .with_state(log_state)
1025    }
1026
1027    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1028    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1029        Router::new()
1030            .route("/logs", routing::post(event::log_ingester))
1031            .route(
1032                "/pipelines/{pipeline_name}",
1033                routing::get(event::query_pipeline),
1034            )
1035            .route(
1036                "/pipelines/{pipeline_name}",
1037                routing::post(event::add_pipeline),
1038            )
1039            .route(
1040                "/pipelines/{pipeline_name}",
1041                routing::delete(event::delete_pipeline),
1042            )
1043            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1044            .layer(
1045                ServiceBuilder::new()
1046                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1047            )
1048            .with_state(log_state)
1049    }
1050
1051    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1052        Router::new()
1053            .route("/ingest", routing::post(event::log_ingester))
1054            .route(
1055                "/pipelines/{pipeline_name}",
1056                routing::get(event::query_pipeline),
1057            )
1058            .route(
1059                "/pipelines/{pipeline_name}/ddl",
1060                routing::get(event::query_pipeline_ddl),
1061            )
1062            .route(
1063                "/pipelines/{pipeline_name}",
1064                routing::post(event::add_pipeline),
1065            )
1066            .route(
1067                "/pipelines/{pipeline_name}",
1068                routing::delete(event::delete_pipeline),
1069            )
1070            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1071            .layer(
1072                ServiceBuilder::new()
1073                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1074            )
1075            .with_state(log_state)
1076    }
1077
1078    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1079        Router::new()
1080            .route("/sql", routing::get(handler::sql).post(handler::sql))
1081            .route(
1082                "/sql/parse",
1083                routing::get(handler::sql_parse).post(handler::sql_parse),
1084            )
1085            .route(
1086                "/sql/format",
1087                routing::get(handler::sql_format).post(handler::sql_format),
1088            )
1089            .route(
1090                "/promql",
1091                routing::get(handler::promql).post(handler::promql),
1092            )
1093            .with_state(api_state)
1094    }
1095
1096    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1097        Router::new()
1098            .route("/logs", routing::get(logs::logs).post(logs::logs))
1099            .with_state(log_handler)
1100    }
1101
1102    /// Route Prometheus [HTTP API].
1103    ///
1104    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1105    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1106        Router::new()
1107            .route(
1108                "/format_query",
1109                routing::post(format_query).get(format_query),
1110            )
1111            .route("/status/buildinfo", routing::get(build_info_query))
1112            .route("/query", routing::post(instant_query).get(instant_query))
1113            .route("/query_range", routing::post(range_query).get(range_query))
1114            .route("/labels", routing::post(labels_query).get(labels_query))
1115            .route("/series", routing::post(series_query).get(series_query))
1116            .route("/parse_query", routing::post(parse_query).get(parse_query))
1117            .route(
1118                "/label/{label_name}/values",
1119                routing::get(label_values_query),
1120            )
1121            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1122            .with_state(prometheus_handler)
1123    }
1124
1125    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1126    /// called `prom_store`.
1127    ///
1128    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1129    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1130    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1131        Router::new()
1132            .route("/read", routing::post(prom_store::remote_read))
1133            .route("/write", routing::post(prom_store::remote_write))
1134            .with_state(state)
1135    }
1136
1137    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1138        Router::new()
1139            .route("/write", routing::post(influxdb_write_v1))
1140            .route("/api/v2/write", routing::post(influxdb_write_v2))
1141            .layer(
1142                ServiceBuilder::new()
1143                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1144            )
1145            .route("/ping", routing::get(influxdb_ping))
1146            .route("/health", routing::get(influxdb_health))
1147            .with_state(influxdb_handler)
1148    }
1149
1150    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1151        Router::new()
1152            .route("/api/put", routing::post(opentsdb::put))
1153            .with_state(opentsdb_handler)
1154    }
1155
1156    fn route_otlp<S>(
1157        otlp_handler: OpenTelemetryProtocolHandlerRef,
1158        with_metric_engine: bool,
1159    ) -> Router<S> {
1160        Router::new()
1161            .route("/v1/metrics", routing::post(otlp::metrics))
1162            .route("/v1/traces", routing::post(otlp::traces))
1163            .route("/v1/logs", routing::post(otlp::logs))
1164            .layer(
1165                ServiceBuilder::new()
1166                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1167            )
1168            .with_state(OtlpState {
1169                with_metric_engine,
1170                handler: otlp_handler,
1171            })
1172    }
1173
1174    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1175        Router::new()
1176            .route("/config", routing::get(handler::config))
1177            .with_state(state)
1178    }
1179
1180    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1181        Router::new()
1182            .route("/api/services", routing::get(jaeger::handle_get_services))
1183            .route(
1184                "/api/services/{service_name}/operations",
1185                routing::get(jaeger::handle_get_operations_by_service),
1186            )
1187            .route(
1188                "/api/operations",
1189                routing::get(jaeger::handle_get_operations),
1190            )
1191            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1192            .route(
1193                "/api/traces/{trace_id}",
1194                routing::get(jaeger::handle_get_trace),
1195            )
1196            .with_state(handler)
1197    }
1198}
1199
1200pub const HTTP_SERVER: &str = "HTTP_SERVER";
1201
1202#[async_trait]
1203impl Server for HttpServer {
1204    async fn shutdown(&self) -> Result<()> {
1205        let mut shutdown_tx = self.shutdown_tx.lock().await;
1206        if let Some(tx) = shutdown_tx.take()
1207            && tx.send(()).is_err()
1208        {
1209            info!("Receiver dropped, the HTTP server has already exited");
1210        }
1211        info!("Shutdown HTTP server");
1212
1213        Ok(())
1214    }
1215
1216    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1217        let (tx, rx) = oneshot::channel();
1218        let serve = {
1219            let mut shutdown_tx = self.shutdown_tx.lock().await;
1220            ensure!(
1221                shutdown_tx.is_none(),
1222                AlreadyStartedSnafu { server: "HTTP" }
1223            );
1224
1225            let mut app = self.make_app();
1226            if let Some(configurator) = self.plugins.get::<HttpConfiguratorRef<()>>() {
1227                app = configurator
1228                    .configure_http(app, ())
1229                    .await
1230                    .context(OtherSnafu)?;
1231            }
1232            let app = self.build(app)?;
1233            let listener = tokio::net::TcpListener::bind(listening)
1234                .await
1235                .context(AddressBindSnafu { addr: listening })?
1236                .tap_io(|tcp_stream| {
1237                    if let Err(e) = tcp_stream.set_nodelay(true) {
1238                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1239                    }
1240                });
1241            let serve = axum::serve(listener, app.into_make_service());
1242
1243            // FIXME(yingwen): Support keepalive.
1244            // See:
1245            // - https://github.com/tokio-rs/axum/discussions/2939
1246            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1247            // let server = axum::Server::try_bind(&listening)
1248            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1249            //     .tcp_nodelay(true)
1250            //     // Enable TCP keepalive to close the dangling established connections.
1251            //     // It's configured to let the keepalive probes first send after the connection sits
1252            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1253            //     // So the connection will be closed after roughly 1 hour.
1254            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1255            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1256            //     .tcp_keepalive_retries(Some(6))
1257            //     .serve(app.into_make_service());
1258
1259            *shutdown_tx = Some(tx);
1260
1261            serve
1262        };
1263        let listening = serve.local_addr().context(InternalIoSnafu)?;
1264        info!("HTTP server is bound to {}", listening);
1265
1266        common_runtime::spawn_global(async move {
1267            if let Err(e) = serve
1268                .with_graceful_shutdown(rx.map(drop))
1269                .await
1270                .context(InternalIoSnafu)
1271            {
1272                error!(e; "Failed to shutdown http server");
1273            }
1274        });
1275
1276        self.bind_addr = Some(listening);
1277        Ok(())
1278    }
1279
1280    fn name(&self) -> &str {
1281        HTTP_SERVER
1282    }
1283
1284    fn bind_addr(&self) -> Option<SocketAddr> {
1285        self.bind_addr
1286    }
1287}
1288
1289#[cfg(test)]
1290mod test {
1291    use std::future::pending;
1292    use std::io::Cursor;
1293    use std::sync::Arc;
1294
1295    use arrow_ipc::reader::FileReader;
1296    use arrow_schema::DataType;
1297    use axum::handler::Handler;
1298    use axum::http::StatusCode;
1299    use axum::routing::get;
1300    use common_query::Output;
1301    use common_recordbatch::RecordBatches;
1302    use datafusion_expr::LogicalPlan;
1303    use datatypes::prelude::*;
1304    use datatypes::schema::{ColumnSchema, Schema};
1305    use datatypes::vectors::{StringVector, UInt32Vector};
1306    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1307    use query::parser::PromQuery;
1308    use query::query_engine::DescribeResult;
1309    use session::context::QueryContextRef;
1310    use sql::statements::statement::Statement;
1311    use tokio::sync::mpsc;
1312    use tokio::time::Instant;
1313
1314    use super::*;
1315    use crate::error::Error;
1316    use crate::http::test_helpers::TestClient;
1317    use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
1318
1319    struct DummyInstance {
1320        _tx: mpsc::Sender<(String, Vec<u8>)>,
1321    }
1322
1323    #[async_trait]
1324    impl SqlQueryHandler for DummyInstance {
1325        type Error = Error;
1326
1327        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1328            unimplemented!()
1329        }
1330
1331        async fn do_promql_query(
1332            &self,
1333            _: &PromQuery,
1334            _: QueryContextRef,
1335        ) -> Vec<std::result::Result<Output, Self::Error>> {
1336            unimplemented!()
1337        }
1338
1339        async fn do_exec_plan(
1340            &self,
1341            _stmt: Option<Statement>,
1342            _plan: LogicalPlan,
1343            _query_ctx: QueryContextRef,
1344        ) -> std::result::Result<Output, Self::Error> {
1345            unimplemented!()
1346        }
1347
1348        async fn do_describe(
1349            &self,
1350            _stmt: sql::statements::statement::Statement,
1351            _query_ctx: QueryContextRef,
1352        ) -> Result<Option<DescribeResult>> {
1353            unimplemented!()
1354        }
1355
1356        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1357            Ok(true)
1358        }
1359    }
1360
1361    fn timeout() -> DynamicTimeoutLayer {
1362        DynamicTimeoutLayer::new(Duration::from_millis(10))
1363    }
1364
1365    async fn forever() {
1366        pending().await
1367    }
1368
1369    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1370        make_test_app_custom(tx, HttpOptions::default())
1371    }
1372
1373    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1374        let instance = Arc::new(DummyInstance { _tx: tx });
1375        let sql_instance = ServerSqlQueryHandlerAdapter::arc(instance.clone());
1376        let server = HttpServerBuilder::new(options)
1377            .with_sql_handler(sql_instance)
1378            .build();
1379        server.build(server.make_app()).unwrap().route(
1380            "/test/timeout",
1381            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1382        )
1383    }
1384
1385    #[tokio::test]
1386    pub async fn test_cors() {
1387        // cors is on by default
1388        let (tx, _rx) = mpsc::channel(100);
1389        let app = make_test_app(tx);
1390        let client = TestClient::new(app).await;
1391
1392        let res = client.get("/health").send().await;
1393
1394        assert_eq!(res.status(), StatusCode::OK);
1395        assert_eq!(
1396            res.headers()
1397                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1398                .expect("expect cors header origin"),
1399            "*"
1400        );
1401
1402        let res = client.get("/v1/health").send().await;
1403
1404        assert_eq!(res.status(), StatusCode::OK);
1405        assert_eq!(
1406            res.headers()
1407                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1408                .expect("expect cors header origin"),
1409            "*"
1410        );
1411
1412        let res = client
1413            .options("/health")
1414            .header("Access-Control-Request-Headers", "x-greptime-auth")
1415            .header("Access-Control-Request-Method", "DELETE")
1416            .header("Origin", "https://example.com")
1417            .send()
1418            .await;
1419        assert_eq!(res.status(), StatusCode::OK);
1420        assert_eq!(
1421            res.headers()
1422                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1423                .expect("expect cors header origin"),
1424            "*"
1425        );
1426        assert_eq!(
1427            res.headers()
1428                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1429                .expect("expect cors header headers"),
1430            "*"
1431        );
1432        assert_eq!(
1433            res.headers()
1434                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1435                .expect("expect cors header methods"),
1436            "GET,POST,PUT,DELETE,HEAD"
1437        );
1438    }
1439
1440    #[tokio::test]
1441    pub async fn test_cors_custom_origins() {
1442        // cors is on by default
1443        let (tx, _rx) = mpsc::channel(100);
1444        let origin = "https://example.com";
1445
1446        let options = HttpOptions {
1447            cors_allowed_origins: vec![origin.to_string()],
1448            ..Default::default()
1449        };
1450
1451        let app = make_test_app_custom(tx, options);
1452        let client = TestClient::new(app).await;
1453
1454        let res = client.get("/health").header("Origin", origin).send().await;
1455
1456        assert_eq!(res.status(), StatusCode::OK);
1457        assert_eq!(
1458            res.headers()
1459                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1460                .expect("expect cors header origin"),
1461            origin
1462        );
1463
1464        let res = client
1465            .get("/health")
1466            .header("Origin", "https://notallowed.com")
1467            .send()
1468            .await;
1469
1470        assert_eq!(res.status(), StatusCode::OK);
1471        assert!(
1472            !res.headers()
1473                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1474        );
1475    }
1476
1477    #[tokio::test]
1478    pub async fn test_cors_disabled() {
1479        // cors is on by default
1480        let (tx, _rx) = mpsc::channel(100);
1481
1482        let options = HttpOptions {
1483            enable_cors: false,
1484            ..Default::default()
1485        };
1486
1487        let app = make_test_app_custom(tx, options);
1488        let client = TestClient::new(app).await;
1489
1490        let res = client.get("/health").send().await;
1491
1492        assert_eq!(res.status(), StatusCode::OK);
1493        assert!(
1494            !res.headers()
1495                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1496        );
1497    }
1498
1499    #[test]
1500    fn test_http_options_default() {
1501        let default = HttpOptions::default();
1502        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1503        assert_eq!(Duration::from_secs(0), default.timeout)
1504    }
1505
1506    #[tokio::test]
1507    async fn test_http_server_request_timeout() {
1508        common_telemetry::init_default_ut_logging();
1509
1510        let (tx, _rx) = mpsc::channel(100);
1511        let app = make_test_app(tx);
1512        let client = TestClient::new(app).await;
1513        let res = client.get("/test/timeout").send().await;
1514        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1515
1516        let now = Instant::now();
1517        let res = client
1518            .get("/test/timeout")
1519            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1520            .send()
1521            .await;
1522        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1523        let elapsed = now.elapsed();
1524        assert!(elapsed > Duration::from_millis(15));
1525
1526        tokio::time::timeout(
1527            Duration::from_millis(15),
1528            client
1529                .get("/test/timeout")
1530                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1531                .send(),
1532        )
1533        .await
1534        .unwrap_err();
1535
1536        tokio::time::timeout(
1537            Duration::from_millis(15),
1538            client
1539                .get("/test/timeout")
1540                .header(
1541                    GREPTIME_DB_HEADER_TIMEOUT,
1542                    humantime::format_duration(Duration::default()).to_string(),
1543                )
1544                .send(),
1545        )
1546        .await
1547        .unwrap_err();
1548    }
1549
1550    #[tokio::test]
1551    async fn test_schema_for_empty_response() {
1552        let column_schemas = vec![
1553            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1554            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1555        ];
1556        let schema = Arc::new(Schema::new(column_schemas));
1557
1558        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1559        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1560
1561        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1562        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1563            let json_output = &json_resp.output[0];
1564            if let GreptimeQueryOutput::Records(r) = json_output {
1565                assert_eq!(r.num_rows(), 0);
1566                assert_eq!(r.num_cols(), 2);
1567                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1568                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1569            } else {
1570                panic!("invalid output type");
1571            }
1572        } else {
1573            panic!("invalid format")
1574        }
1575    }
1576
1577    #[tokio::test]
1578    async fn test_recordbatches_conversion() {
1579        let column_schemas = vec![
1580            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1581            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1582        ];
1583        let schema = Arc::new(Schema::new(column_schemas));
1584        let columns: Vec<VectorRef> = vec![
1585            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1586            Arc::new(StringVector::from(vec![
1587                None,
1588                Some("hello"),
1589                Some("greptime"),
1590                None,
1591            ])),
1592        ];
1593        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1594
1595        for format in [
1596            ResponseFormat::GreptimedbV1,
1597            ResponseFormat::InfluxdbV1,
1598            ResponseFormat::Csv(true, true),
1599            ResponseFormat::Table,
1600            ResponseFormat::Arrow,
1601            ResponseFormat::Json,
1602            ResponseFormat::Null,
1603        ] {
1604            let recordbatches =
1605                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1606            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1607            let json_resp = match format {
1608                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1609                ResponseFormat::Csv(with_names, with_types) => {
1610                    CsvResponse::from_output(outputs, with_names, with_types).await
1611                }
1612                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1613                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1614                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1615                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1616                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1617            };
1618
1619            match json_resp {
1620                HttpResponse::GreptimedbV1(resp) => {
1621                    let json_output = &resp.output[0];
1622                    if let GreptimeQueryOutput::Records(r) = json_output {
1623                        assert_eq!(r.num_rows(), 4);
1624                        assert_eq!(r.num_cols(), 2);
1625                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1626                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1627                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1628                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1629                    } else {
1630                        panic!("invalid output type");
1631                    }
1632                }
1633                HttpResponse::InfluxdbV1(resp) => {
1634                    let json_output = &resp.results()[0];
1635                    assert_eq!(json_output.num_rows(), 4);
1636                    assert_eq!(json_output.num_cols(), 2);
1637                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1638                    assert_eq!(
1639                        json_output.series[0].values[0][0],
1640                        serde_json::Value::from(1)
1641                    );
1642                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1643                }
1644                HttpResponse::Csv(resp) => {
1645                    let output = &resp.output()[0];
1646                    if let GreptimeQueryOutput::Records(r) = output {
1647                        assert_eq!(r.num_rows(), 4);
1648                        assert_eq!(r.num_cols(), 2);
1649                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1650                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1651                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1652                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1653                    } else {
1654                        panic!("invalid output type");
1655                    }
1656                }
1657
1658                HttpResponse::Table(resp) => {
1659                    let output = &resp.output()[0];
1660                    if let GreptimeQueryOutput::Records(r) = output {
1661                        assert_eq!(r.num_rows(), 4);
1662                        assert_eq!(r.num_cols(), 2);
1663                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1664                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1665                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1666                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1667                    } else {
1668                        panic!("invalid output type");
1669                    }
1670                }
1671
1672                HttpResponse::Arrow(resp) => {
1673                    let output = resp.data;
1674                    let mut reader =
1675                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
1676                    let schema = reader.schema();
1677                    assert_eq!(schema.fields[0].name(), "numbers");
1678                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1679                    assert_eq!(schema.fields[1].name(), "strings");
1680                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1681
1682                    let rb = reader.next().unwrap().expect("read record batch failed");
1683                    assert_eq!(rb.num_columns(), 2);
1684                    assert_eq!(rb.num_rows(), 4);
1685                }
1686
1687                HttpResponse::Json(resp) => {
1688                    let output = &resp.output()[0];
1689                    if let GreptimeQueryOutput::Records(r) = output {
1690                        assert_eq!(r.num_rows(), 4);
1691                        assert_eq!(r.num_cols(), 2);
1692                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1693                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1694                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1695                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1696                    } else {
1697                        panic!("invalid output type");
1698                    }
1699                }
1700
1701                HttpResponse::Null(resp) => {
1702                    assert_eq!(resp.rows(), 4);
1703                }
1704
1705                HttpResponse::Error(err) => unreachable!("{err:?}"),
1706            }
1707        }
1708    }
1709
1710    #[test]
1711    fn test_response_format_misc() {
1712        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
1713        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
1714        assert_eq!(
1715            ResponseFormat::parse("csv"),
1716            Some(ResponseFormat::Csv(false, false))
1717        );
1718        assert_eq!(
1719            ResponseFormat::parse("csvwithnames"),
1720            Some(ResponseFormat::Csv(true, false))
1721        );
1722        assert_eq!(
1723            ResponseFormat::parse("csvwithnamesandtypes"),
1724            Some(ResponseFormat::Csv(true, true))
1725        );
1726        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
1727        assert_eq!(
1728            ResponseFormat::parse("greptimedb_v1"),
1729            Some(ResponseFormat::GreptimedbV1)
1730        );
1731        assert_eq!(
1732            ResponseFormat::parse("influxdb_v1"),
1733            Some(ResponseFormat::InfluxdbV1)
1734        );
1735        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
1736        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));
1737
1738        // invalid formats
1739        assert_eq!(ResponseFormat::parse("invalid"), None);
1740        assert_eq!(ResponseFormat::parse(""), None);
1741        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive
1742
1743        // as str
1744        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
1745        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
1746        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
1747        assert_eq!(ResponseFormat::Table.as_str(), "table");
1748        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
1749        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
1750        assert_eq!(ResponseFormat::Json.as_str(), "json");
1751        assert_eq!(ResponseFormat::Null.as_str(), "null");
1752        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
1753    }
1754}