servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::Display;
17use std::net::SocketAddr;
18use std::sync::Mutex as StdMutex;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use auth::UserProviderRef;
23use axum::extract::DefaultBodyLimit;
24use axum::http::StatusCode as HttpStatusCode;
25use axum::response::{IntoResponse, Response};
26use axum::serve::ListenerExt;
27use axum::{Router, middleware, routing};
28use common_base::Plugins;
29use common_base::readable_size::ReadableSize;
30use common_recordbatch::RecordBatch;
31use common_telemetry::{debug, error, info};
32use common_time::Timestamp;
33use common_time::timestamp::TimeUnit;
34use datatypes::data_type::DataType;
35use datatypes::schema::SchemaRef;
36use event::{LogState, LogValidatorRef};
37use futures::FutureExt;
38use http::{HeaderValue, Method};
39use prost::DecodeError;
40use serde::{Deserialize, Serialize};
41use serde_json::Value;
42use snafu::{ResultExt, ensure};
43use tokio::sync::Mutex;
44use tokio::sync::oneshot::{self, Sender};
45use tower::ServiceBuilder;
46use tower_http::compression::CompressionLayer;
47use tower_http::cors::{AllowOrigin, Any, CorsLayer};
48use tower_http::decompression::RequestDecompressionLayer;
49use tower_http::trace::TraceLayer;
50
51use self::authorize::AuthState;
52use self::result::table_result::TableResponse;
53use crate::configurator::HttpConfiguratorRef;
54use crate::elasticsearch;
55use crate::error::{
56    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu,
57    OtherSnafu, Result,
58};
59use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
60use crate::http::otlp::OtlpState;
61use crate::http::prom_store::PromStoreState;
62use crate::http::prometheus::{
63    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
64    range_query, series_query,
65};
66use crate::http::result::arrow_result::ArrowResponse;
67use crate::http::result::csv_result::CsvResponse;
68use crate::http::result::error_result::ErrorResponse;
69use crate::http::result::greptime_result_v1::GreptimedbV1Response;
70use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
71use crate::http::result::json_result::JsonResponse;
72use crate::http::result::null_result::NullResponse;
73use crate::interceptor::LogIngestInterceptorRef;
74use crate::metrics::http_metrics_layer;
75use crate::metrics_handler::MetricsHandler;
76use crate::prometheus_handler::PrometheusHandlerRef;
77use crate::query_handler::sql::ServerSqlQueryHandlerRef;
78use crate::query_handler::{
79    InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
80    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
81    PromStoreProtocolHandlerRef,
82};
83use crate::request_limiter::RequestMemoryLimiter;
84use crate::server::Server;
85
86pub mod authorize;
87#[cfg(feature = "dashboard")]
88mod dashboard;
89pub mod dyn_log;
90pub mod dyn_trace;
91pub mod event;
92pub mod extractor;
93pub mod handler;
94pub mod header;
95pub mod influxdb;
96pub mod jaeger;
97pub mod logs;
98pub mod loki;
99pub mod mem_prof;
100mod memory_limit;
101pub mod opentsdb;
102pub mod otlp;
103pub mod pprof;
104pub mod prom_store;
105pub mod prometheus;
106pub mod result;
107mod timeout;
108pub mod utils;
109
110use result::HttpOutputWriter;
111pub(crate) use timeout::DynamicTimeoutLayer;
112
113mod hints;
114mod read_preference;
115#[cfg(any(test, feature = "testing"))]
116pub mod test_helpers;
117
/// Version segment under which the builder nests most API routers (`/{HTTP_API_VERSION}/...`).
pub const HTTP_API_VERSION: &str = "v1";
/// Path prefix of versioned API routes.
pub const HTTP_API_PREFIX: &str = "/v1/";
/// Default http body limit (64M). Used by `HttpOptions::default()`.
const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);

/// Authorization header
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";

// TODO(fys): This is a temporary workaround, it will be improved later
// NOTE(review): presumably these paths are exempt from authentication — confirm in `authorize`.
pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
128
/// HTTP server assembled by [HttpServerBuilder::build].
#[derive(Default)]
pub struct HttpServer {
    // Routes registered through the builder; `make_app` clones them when assembling the app.
    router: StdMutex<Router>,
    // Sender half of a oneshot channel; presumably used to signal shutdown (not visible here).
    shutdown_tx: Mutex<Option<Sender<()>>>,
    // Handed to the auth middleware (`AuthState`) in `build`.
    user_provider: Option<UserProviderRef>,
    // Caps memory used by concurrent request bodies; built from `max_total_body_memory`.
    memory_limiter: RequestMemoryLimiter,

    // plugins
    plugins: Plugins,

    // server configs
    options: HttpOptions,
    // The builder leaves this as `None`; NOTE(review): presumably filled in once bound.
    bind_addr: Option<SocketAddr>,
}
143
/// Configuration of the HTTP server. Missing fields fall back to `HttpOptions::default()`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct HttpOptions {
    /// Listen address (default `127.0.0.1:4000`).
    pub addr: String,

    /// Request timeout, written in humantime format. A zero duration disables the timeout layer.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,

    /// When set, `make_app` skips the dashboard routes (only relevant with the `dashboard`
    /// feature). Never read from or written to config files.
    #[serde(skip)]
    pub disable_dashboard: bool,

    /// Per-request body size limit. Zero disables the limit.
    pub body_limit: ReadableSize,

    /// Maximum total memory for all concurrent HTTP request bodies. 0 disables the limit.
    pub max_total_body_memory: ReadableSize,

    /// Validation mode while decoding Prometheus remote write requests.
    pub prom_validation_mode: PromValidationMode,

    /// Origins allowed by the CORS layer. An empty list allows any origin.
    pub cors_allowed_origins: Vec<String>,

    /// Whether the CORS layer is installed at all.
    pub enable_cors: bool,
}
167
/// How strictly UTF-8 is validated when decoding strings in Prometheus remote-write requests.
/// Serialized in snake_case (`strict`, `lossy`, `unchecked`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PromValidationMode {
    /// Force UTF8 validation
    Strict,
    /// Allow lossy UTF8 strings
    Lossy,
    /// Do not validate UTF8 strings.
    Unchecked,
}
178
179impl PromValidationMode {
180    /// Decodes provided bytes to [String] with optional UTF-8 validation.
181    pub fn decode_string(&self, bytes: &[u8]) -> std::result::Result<String, DecodeError> {
182        let result = match self {
183            PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
184                Ok(s) => s,
185                Err(e) => {
186                    debug!("Invalid UTF-8 string value: {:?}, error: {:?}", bytes, e);
187                    return Err(DecodeError::new("invalid utf-8"));
188                }
189            },
190            PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
191            PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
192        };
193        Ok(result)
194    }
195}
196
197impl Default for HttpOptions {
198    fn default() -> Self {
199        Self {
200            addr: "127.0.0.1:4000".to_string(),
201            timeout: Duration::from_secs(0),
202            disable_dashboard: false,
203            body_limit: DEFAULT_BODY_LIMIT,
204            max_total_body_memory: ReadableSize(0),
205            cors_allowed_origins: Vec::new(),
206            enable_cors: true,
207            prom_validation_mode: PromValidationMode::Strict,
208        }
209    }
210}
211
/// Name and type name of one result column, as reported in HTTP query responses.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct ColumnSchema {
    name: String,
    // Type name as a string (filled from the column's data type `name()`).
    data_type: String,
}
217
218impl ColumnSchema {
219    pub fn new(name: String, data_type: String) -> ColumnSchema {
220        ColumnSchema { name, data_type }
221    }
222}
223
/// Ordered list of result columns returned alongside query rows.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct OutputSchema {
    column_schemas: Vec<ColumnSchema>,
}
228
229impl OutputSchema {
230    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
231        OutputSchema {
232            column_schemas: columns,
233        }
234    }
235}
236
237impl From<SchemaRef> for OutputSchema {
238    fn from(schema: SchemaRef) -> OutputSchema {
239        OutputSchema {
240            column_schemas: schema
241                .column_schemas()
242                .iter()
243                .map(|cs| ColumnSchema {
244                    name: cs.name.clone(),
245                    data_type: cs.data_type.name(),
246                })
247                .collect(),
248        }
249    }
250}
251
/// Schema and rows of a query result, as returned over HTTP.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct HttpRecordsOutput {
    schema: OutputSchema,
    // One JSON value list per row (presumably one value per schema column).
    rows: Vec<Vec<Value>>,
    // total_rows is equal to rows.len() in most cases,
    // the Dashboard query result may be truncated, so we need to return the total_rows.
    // Defaults to 0 when the field is absent during deserialization.
    #[serde(default)]
    total_rows: usize,

    // plan level execution metrics
    // Omitted from the serialized output when empty.
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    #[serde(default)]
    metrics: HashMap<String, Value>,
}
266
267impl HttpRecordsOutput {
268    pub fn num_rows(&self) -> usize {
269        self.rows.len()
270    }
271
272    pub fn num_cols(&self) -> usize {
273        self.schema.column_schemas.len()
274    }
275
276    pub fn schema(&self) -> &OutputSchema {
277        &self.schema
278    }
279
280    pub fn rows(&self) -> &Vec<Vec<Value>> {
281        &self.rows
282    }
283}
284
285impl HttpRecordsOutput {
286    pub fn try_new(
287        schema: SchemaRef,
288        recordbatches: Vec<RecordBatch>,
289    ) -> std::result::Result<HttpRecordsOutput, Error> {
290        if recordbatches.is_empty() {
291            Ok(HttpRecordsOutput {
292                schema: OutputSchema::from(schema),
293                rows: vec![],
294                total_rows: 0,
295                metrics: Default::default(),
296            })
297        } else {
298            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
299            let mut rows = Vec::with_capacity(num_rows);
300
301            for recordbatch in recordbatches {
302                let mut writer = HttpOutputWriter::new(schema.num_columns(), None);
303                writer.write(recordbatch, &mut rows)?;
304            }
305
306            Ok(HttpRecordsOutput {
307                schema: OutputSchema::from(schema),
308                total_rows: rows.len(),
309                rows,
310                metrics: Default::default(),
311            })
312        }
313    }
314}
315
/// Result of a single statement: an affected-row count or a set of records.
/// Variant names are serialized in lowercase (`affectedrows`, `records`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GreptimeQueryOutput {
    AffectedRows(usize),
    Records(HttpRecordsOutput),
}
322
/// Output format used to present SQL query results over HTTP.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    // (with_names, with_types)
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Looks up a format by its exact (case-sensitive) name; unknown names give `None`.
    pub fn parse(s: &str) -> Option<Self> {
        const KNOWN: [(&str, ResponseFormat); 9] = [
            ("arrow", ResponseFormat::Arrow),
            ("csv", ResponseFormat::Csv(false, false)),
            ("csvwithnames", ResponseFormat::Csv(true, false)),
            ("csvwithnamesandtypes", ResponseFormat::Csv(true, true)),
            ("table", ResponseFormat::Table),
            ("greptimedb_v1", ResponseFormat::GreptimedbV1),
            ("influxdb_v1", ResponseFormat::InfluxdbV1),
            ("json", ResponseFormat::Json),
            ("null", ResponseFormat::Null),
        ];

        KNOWN
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, format)| format)
    }

    /// Canonical name of the format; every CSV flavour reports plain `"csv"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
365
366#[derive(Debug, Clone, Copy, PartialEq, Eq)]
367pub enum Epoch {
368    Nanosecond,
369    Microsecond,
370    Millisecond,
371    Second,
372}
373
374impl Epoch {
375    pub fn parse(s: &str) -> Option<Epoch> {
376        // Both u and µ indicate microseconds.
377        // epoch = [ns,u,µ,ms,s],
378        // For details, see the Influxdb documents.
379        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
380        match s {
381            "ns" => Some(Epoch::Nanosecond),
382            "u" | "µ" => Some(Epoch::Microsecond),
383            "ms" => Some(Epoch::Millisecond),
384            "s" => Some(Epoch::Second),
385            _ => None, // just returns None for other cases
386        }
387    }
388
389    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
390        match self {
391            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
392            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
393            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
394            Epoch::Second => ts.convert_to(TimeUnit::Second),
395        }
396    }
397}
398
399impl Display for Epoch {
400    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401        match self {
402            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
403            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
404            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
405            Epoch::Second => write!(f, "Epoch::Second"),
406        }
407    }
408}
409
/// A query response in one of the supported output formats, or an error response.
#[derive(Serialize, Deserialize, Debug)]
pub enum HttpResponse {
    Arrow(ArrowResponse),
    Csv(CsvResponse),
    Table(TableResponse),
    Error(ErrorResponse),
    GreptimedbV1(GreptimedbV1Response),
    InfluxdbV1(InfluxdbV1Response),
    Json(JsonResponse),
    Null(NullResponse),
}
421
422impl HttpResponse {
423    pub fn with_execution_time(self, execution_time: u64) -> Self {
424        match self {
425            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
426            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
427            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
428            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
429            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
430            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
431            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
432            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
433        }
434    }
435
436    pub fn with_limit(self, limit: usize) -> Self {
437        match self {
438            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
439            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
440            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
441            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
442            _ => self,
443        }
444    }
445}
446
447pub fn process_with_limit(
448    mut outputs: Vec<GreptimeQueryOutput>,
449    limit: usize,
450) -> Vec<GreptimeQueryOutput> {
451    outputs
452        .drain(..)
453        .map(|data| match data {
454            GreptimeQueryOutput::Records(mut records) => {
455                if records.rows.len() > limit {
456                    records.rows.truncate(limit);
457                    records.total_rows = limit;
458                }
459                GreptimeQueryOutput::Records(records)
460            }
461            _ => data,
462        })
463        .collect()
464}
465
466impl IntoResponse for HttpResponse {
467    fn into_response(self) -> Response {
468        match self {
469            HttpResponse::Arrow(resp) => resp.into_response(),
470            HttpResponse::Csv(resp) => resp.into_response(),
471            HttpResponse::Table(resp) => resp.into_response(),
472            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
473            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
474            HttpResponse::Json(resp) => resp.into_response(),
475            HttpResponse::Null(resp) => resp.into_response(),
476            HttpResponse::Error(resp) => resp.into_response(),
477        }
478    }
479}
480
481impl From<ArrowResponse> for HttpResponse {
482    fn from(value: ArrowResponse) -> Self {
483        HttpResponse::Arrow(value)
484    }
485}
486
487impl From<CsvResponse> for HttpResponse {
488    fn from(value: CsvResponse) -> Self {
489        HttpResponse::Csv(value)
490    }
491}
492
493impl From<TableResponse> for HttpResponse {
494    fn from(value: TableResponse) -> Self {
495        HttpResponse::Table(value)
496    }
497}
498
499impl From<ErrorResponse> for HttpResponse {
500    fn from(value: ErrorResponse) -> Self {
501        HttpResponse::Error(value)
502    }
503}
504
505impl From<GreptimedbV1Response> for HttpResponse {
506    fn from(value: GreptimedbV1Response) -> Self {
507        HttpResponse::GreptimedbV1(value)
508    }
509}
510
511impl From<InfluxdbV1Response> for HttpResponse {
512    fn from(value: InfluxdbV1Response) -> Self {
513        HttpResponse::InfluxdbV1(value)
514    }
515}
516
517impl From<JsonResponse> for HttpResponse {
518    fn from(value: JsonResponse) -> Self {
519        HttpResponse::Json(value)
520    }
521}
522
523impl From<NullResponse> for HttpResponse {
524    fn from(value: NullResponse) -> Self {
525        HttpResponse::Null(value)
526    }
527}
528
/// Shared state of the SQL routes (passed to `HttpServer::route_sql`).
#[derive(Clone)]
pub struct ApiState {
    pub sql_handler: ServerSqlQueryHandlerRef,
}

/// State of the config route: the server options as a string, in whatever form the caller of
/// `with_greptime_config_options` supplied.
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    pub greptime_config_options: String,
}
538
/// Collects routes, plugins and auth settings, then produces an [HttpServer] via `build`.
#[derive(Default)]
pub struct HttpServerBuilder {
    // `max_total_body_memory` from these options sizes the request memory limiter in `build`.
    options: HttpOptions,
    plugins: Plugins,
    // Forwarded to the server for its auth middleware.
    user_provider: Option<UserProviderRef>,
    // Routes accumulated by the `with_*` methods.
    router: Router,
}
546
547impl HttpServerBuilder {
548    pub fn new(options: HttpOptions) -> Self {
549        Self {
550            options,
551            plugins: Plugins::default(),
552            user_provider: None,
553            router: Router::new(),
554        }
555    }
556
557    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
558        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
559
560        Self {
561            router: self
562                .router
563                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
564            ..self
565        }
566    }
567
568    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
569        let logs_router = HttpServer::route_logs(logs_handler);
570
571        Self {
572            router: self
573                .router
574                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
575            ..self
576        }
577    }
578
579    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
580        Self {
581            router: self.router.nest(
582                &format!("/{HTTP_API_VERSION}/opentsdb"),
583                HttpServer::route_opentsdb(handler),
584            ),
585            ..self
586        }
587    }
588
589    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
590        Self {
591            router: self.router.nest(
592                &format!("/{HTTP_API_VERSION}/influxdb"),
593                HttpServer::route_influxdb(handler),
594            ),
595            ..self
596        }
597    }
598
599    pub fn with_prom_handler(
600        self,
601        handler: PromStoreProtocolHandlerRef,
602        pipeline_handler: Option<PipelineHandlerRef>,
603        prom_store_with_metric_engine: bool,
604        prom_validation_mode: PromValidationMode,
605    ) -> Self {
606        let state = PromStoreState {
607            prom_store_handler: handler,
608            pipeline_handler,
609            prom_store_with_metric_engine,
610            prom_validation_mode,
611        };
612
613        Self {
614            router: self.router.nest(
615                &format!("/{HTTP_API_VERSION}/prometheus"),
616                HttpServer::route_prom(state),
617            ),
618            ..self
619        }
620    }
621
622    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
623        Self {
624            router: self.router.nest(
625                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
626                HttpServer::route_prometheus(handler),
627            ),
628            ..self
629        }
630    }
631
632    pub fn with_otlp_handler(
633        self,
634        handler: OpenTelemetryProtocolHandlerRef,
635        with_metric_engine: bool,
636    ) -> Self {
637        Self {
638            router: self.router.nest(
639                &format!("/{HTTP_API_VERSION}/otlp"),
640                HttpServer::route_otlp(handler, with_metric_engine),
641            ),
642            ..self
643        }
644    }
645
646    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
647        Self {
648            user_provider: Some(user_provider),
649            ..self
650        }
651    }
652
653    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
654        Self {
655            router: self.router.merge(HttpServer::route_metrics(handler)),
656            ..self
657        }
658    }
659
660    pub fn with_log_ingest_handler(
661        self,
662        handler: PipelineHandlerRef,
663        validator: Option<LogValidatorRef>,
664        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
665    ) -> Self {
666        let log_state = LogState {
667            log_handler: handler,
668            log_validator: validator,
669            ingest_interceptor,
670        };
671
672        let router = self.router.nest(
673            &format!("/{HTTP_API_VERSION}"),
674            HttpServer::route_pipelines(log_state.clone()),
675        );
676        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
677        let router = router.nest(
678            &format!("/{HTTP_API_VERSION}/events"),
679            #[allow(deprecated)]
680            HttpServer::route_log_deprecated(log_state.clone()),
681        );
682
683        let router = router.nest(
684            &format!("/{HTTP_API_VERSION}/loki"),
685            HttpServer::route_loki(log_state.clone()),
686        );
687
688        let router = router.nest(
689            &format!("/{HTTP_API_VERSION}/elasticsearch"),
690            HttpServer::route_elasticsearch(log_state.clone()),
691        );
692
693        let router = router.nest(
694            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
695            Router::new()
696                .route("/", routing::get(elasticsearch::handle_get_version))
697                .with_state(log_state),
698        );
699
700        Self { router, ..self }
701    }
702
703    pub fn with_plugins(self, plugins: Plugins) -> Self {
704        Self { plugins, ..self }
705    }
706
707    pub fn with_greptime_config_options(self, opts: String) -> Self {
708        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
709            greptime_config_options: opts,
710        });
711
712        Self {
713            router: self.router.merge(config_router),
714            ..self
715        }
716    }
717
718    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
719        Self {
720            router: self.router.nest(
721                &format!("/{HTTP_API_VERSION}/jaeger"),
722                HttpServer::route_jaeger(handler),
723            ),
724            ..self
725        }
726    }
727
728    pub fn with_extra_router(self, router: Router) -> Self {
729        Self {
730            router: self.router.merge(router),
731            ..self
732        }
733    }
734
735    pub fn build(self) -> HttpServer {
736        let memory_limiter =
737            RequestMemoryLimiter::new(self.options.max_total_body_memory.as_bytes() as usize);
738        HttpServer {
739            options: self.options,
740            user_provider: self.user_provider,
741            shutdown_tx: Mutex::new(None),
742            plugins: self.plugins,
743            router: StdMutex::new(self.router),
744            bind_addr: None,
745            memory_limiter,
746        }
747    }
748}
749
750impl HttpServer {
751    /// Gets the router and adds necessary root routes (health, status, dashboard).
752    pub fn make_app(&self) -> Router {
753        let mut router = {
754            let router = self.router.lock().unwrap();
755            router.clone()
756        };
757
758        router = router
759            .route(
760                "/health",
761                routing::get(handler::health).post(handler::health),
762            )
763            .route(
764                &format!("/{HTTP_API_VERSION}/health"),
765                routing::get(handler::health).post(handler::health),
766            )
767            .route(
768                "/ready",
769                routing::get(handler::health).post(handler::health),
770            );
771
772        router = router.route("/status", routing::get(handler::status));
773
774        #[cfg(feature = "dashboard")]
775        {
776            if !self.options.disable_dashboard {
777                info!("Enable dashboard service at '/dashboard'");
778                // redirect /dashboard to /dashboard/
779                router = router.route(
780                    "/dashboard",
781                    routing::get(|uri: axum::http::uri::Uri| async move {
782                        let path = uri.path();
783                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
784
785                        let new_uri = format!("{}/{}", path, query);
786                        axum::response::Redirect::permanent(&new_uri)
787                    }),
788                );
789
790                // "/dashboard" and "/dashboard/" are two different paths in Axum.
791                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
792                // So we explicitly route "/dashboard/" here.
793                router = router
794                    .route(
795                        "/dashboard/",
796                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
797                    )
798                    .route(
799                        "/dashboard/{*x}",
800                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
801                    );
802            }
803        }
804
805        // Add a layer to collect HTTP metrics for axum.
806        router = router.route_layer(middleware::from_fn(http_metrics_layer));
807
808        router
809    }
810
811    /// Attaches middlewares and debug routes to the router.
812    /// Callers should call this method after [HttpServer::make_app()].
813    pub fn build(&self, router: Router) -> Result<Router> {
814        let timeout_layer = if self.options.timeout != Duration::default() {
815            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
816        } else {
817            info!("HTTP server timeout is disabled");
818            None
819        };
820        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
821            Some(
822                ServiceBuilder::new()
823                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
824            )
825        } else {
826            info!("HTTP server body limit is disabled");
827            None
828        };
829        let cors_layer = if self.options.enable_cors {
830            Some(
831                CorsLayer::new()
832                    .allow_methods([
833                        Method::GET,
834                        Method::POST,
835                        Method::PUT,
836                        Method::DELETE,
837                        Method::HEAD,
838                    ])
839                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
840                        AllowOrigin::from(Any)
841                    } else {
842                        AllowOrigin::from(
843                            self.options
844                                .cors_allowed_origins
845                                .iter()
846                                .map(|s| {
847                                    HeaderValue::from_str(s.as_str())
848                                        .context(InvalidHeaderValueSnafu)
849                                })
850                                .collect::<Result<Vec<HeaderValue>>>()?,
851                        )
852                    })
853                    .allow_headers(Any),
854            )
855        } else {
856            info!("HTTP server cross-origin is disabled");
857            None
858        };
859
860        Ok(router
861            // middlewares
862            .layer(
863                ServiceBuilder::new()
864                    // disable on failure tracing. because printing out isn't very helpful,
865                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
866                    .layer(TraceLayer::new_for_http().on_failure(()))
867                    .option_layer(cors_layer)
868                    .option_layer(timeout_layer)
869                    .option_layer(body_limit_layer)
870                    // memory limit layer - must be before body is consumed
871                    .layer(middleware::from_fn_with_state(
872                        self.memory_limiter.clone(),
873                        memory_limit::memory_limit_middleware,
874                    ))
875                    // auth layer
876                    .layer(middleware::from_fn_with_state(
877                        AuthState::new(self.user_provider.clone()),
878                        authorize::check_http_auth,
879                    ))
880                    .layer(middleware::from_fn(hints::extract_hints))
881                    .layer(middleware::from_fn(
882                        read_preference::extract_read_preference,
883                    )),
884            )
885            // Handlers for debug, we don't expect a timeout.
886            .nest(
887                "/debug",
888                Router::new()
889                    // handler for changing log level dynamically
890                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
891                    .route("/enable_trace", routing::post(dyn_trace::dyn_trace_handler))
892                    .nest(
893                        "/prof",
894                        Router::new()
895                            .route("/cpu", routing::post(pprof::pprof_handler))
896                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
897                            .route(
898                                "/mem/activate",
899                                routing::post(mem_prof::activate_heap_prof_handler),
900                            )
901                            .route(
902                                "/mem/deactivate",
903                                routing::post(mem_prof::deactivate_heap_prof_handler),
904                            )
905                            .route(
906                                "/mem/status",
907                                routing::get(mem_prof::heap_prof_status_handler),
908                            ) // jemalloc gdump flag status and toggle
909                            .route(
910                                "/mem/gdump",
911                                routing::get(mem_prof::gdump_status_handler)
912                                    .post(mem_prof::gdump_toggle_handler),
913                            ),
914                    ),
915            ))
916    }
917
918    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
919        Router::new()
920            .route("/metrics", routing::get(handler::metrics))
921            .with_state(metrics_handler)
922    }
923
924    fn route_loki<S>(log_state: LogState) -> Router<S> {
925        Router::new()
926            .route("/api/v1/push", routing::post(loki::loki_ingest))
927            .layer(
928                ServiceBuilder::new()
929                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
930            )
931            .with_state(log_state)
932    }
933
934    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
935        Router::new()
936            // Return fake responsefor HEAD '/' request.
937            .route(
938                "/",
939                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
940            )
941            // Return fake response for Elasticsearch version request.
942            .route("/", routing::get(elasticsearch::handle_get_version))
943            // Return fake response for Elasticsearch license request.
944            .route("/_license", routing::get(elasticsearch::handle_get_license))
945            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
946            .route(
947                "/{index}/_bulk",
948                routing::post(elasticsearch::handle_bulk_api_with_index),
949            )
950            // Return fake response for Elasticsearch ilm request.
951            .route(
952                "/_ilm/policy/{*path}",
953                routing::any((
954                    HttpStatusCode::OK,
955                    elasticsearch::elasticsearch_headers(),
956                    axum::Json(serde_json::json!({})),
957                )),
958            )
959            // Return fake response for Elasticsearch index template request.
960            .route(
961                "/_index_template/{*path}",
962                routing::any((
963                    HttpStatusCode::OK,
964                    elasticsearch::elasticsearch_headers(),
965                    axum::Json(serde_json::json!({})),
966                )),
967            )
968            // Return fake response for Elasticsearch ingest pipeline request.
969            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
970            .route(
971                "/_ingest/{*path}",
972                routing::any((
973                    HttpStatusCode::OK,
974                    elasticsearch::elasticsearch_headers(),
975                    axum::Json(serde_json::json!({})),
976                )),
977            )
978            // Return fake response for Elasticsearch nodes discovery request.
979            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
980            .route(
981                "/_nodes/{*path}",
982                routing::any((
983                    HttpStatusCode::OK,
984                    elasticsearch::elasticsearch_headers(),
985                    axum::Json(serde_json::json!({})),
986                )),
987            )
988            // Return fake response for Logstash APIs requests.
989            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
990            .route(
991                "/logstash/{*path}",
992                routing::any((
993                    HttpStatusCode::OK,
994                    elasticsearch::elasticsearch_headers(),
995                    axum::Json(serde_json::json!({})),
996                )),
997            )
998            .route(
999                "/_logstash/{*path}",
1000                routing::any((
1001                    HttpStatusCode::OK,
1002                    elasticsearch::elasticsearch_headers(),
1003                    axum::Json(serde_json::json!({})),
1004                )),
1005            )
1006            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1007            .with_state(log_state)
1008    }
1009
1010    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1011    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1012        Router::new()
1013            .route("/logs", routing::post(event::log_ingester))
1014            .route(
1015                "/pipelines/{pipeline_name}",
1016                routing::get(event::query_pipeline),
1017            )
1018            .route(
1019                "/pipelines/{pipeline_name}",
1020                routing::post(event::add_pipeline),
1021            )
1022            .route(
1023                "/pipelines/{pipeline_name}",
1024                routing::delete(event::delete_pipeline),
1025            )
1026            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1027            .layer(
1028                ServiceBuilder::new()
1029                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1030            )
1031            .with_state(log_state)
1032    }
1033
1034    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1035        Router::new()
1036            .route("/ingest", routing::post(event::log_ingester))
1037            .route(
1038                "/pipelines/{pipeline_name}",
1039                routing::get(event::query_pipeline),
1040            )
1041            .route(
1042                "/pipelines/{pipeline_name}/ddl",
1043                routing::get(event::query_pipeline_ddl),
1044            )
1045            .route(
1046                "/pipelines/{pipeline_name}",
1047                routing::post(event::add_pipeline),
1048            )
1049            .route(
1050                "/pipelines/{pipeline_name}",
1051                routing::delete(event::delete_pipeline),
1052            )
1053            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1054            .layer(
1055                ServiceBuilder::new()
1056                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1057            )
1058            .with_state(log_state)
1059    }
1060
1061    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1062        Router::new()
1063            .route("/sql", routing::get(handler::sql).post(handler::sql))
1064            .route(
1065                "/sql/parse",
1066                routing::get(handler::sql_parse).post(handler::sql_parse),
1067            )
1068            .route(
1069                "/sql/format",
1070                routing::get(handler::sql_format).post(handler::sql_format),
1071            )
1072            .route(
1073                "/promql",
1074                routing::get(handler::promql).post(handler::promql),
1075            )
1076            .with_state(api_state)
1077    }
1078
1079    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1080        Router::new()
1081            .route("/logs", routing::get(logs::logs).post(logs::logs))
1082            .with_state(log_handler)
1083    }
1084
1085    /// Route Prometheus [HTTP API].
1086    ///
1087    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1088    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1089        Router::new()
1090            .route(
1091                "/format_query",
1092                routing::post(format_query).get(format_query),
1093            )
1094            .route("/status/buildinfo", routing::get(build_info_query))
1095            .route("/query", routing::post(instant_query).get(instant_query))
1096            .route("/query_range", routing::post(range_query).get(range_query))
1097            .route("/labels", routing::post(labels_query).get(labels_query))
1098            .route("/series", routing::post(series_query).get(series_query))
1099            .route("/parse_query", routing::post(parse_query).get(parse_query))
1100            .route(
1101                "/label/{label_name}/values",
1102                routing::get(label_values_query),
1103            )
1104            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1105            .with_state(prometheus_handler)
1106    }
1107
1108    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1109    /// called `prom_store`.
1110    ///
1111    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1112    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1113    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1114        Router::new()
1115            .route("/read", routing::post(prom_store::remote_read))
1116            .route("/write", routing::post(prom_store::remote_write))
1117            .with_state(state)
1118    }
1119
1120    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1121        Router::new()
1122            .route("/write", routing::post(influxdb_write_v1))
1123            .route("/api/v2/write", routing::post(influxdb_write_v2))
1124            .layer(
1125                ServiceBuilder::new()
1126                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1127            )
1128            .route("/ping", routing::get(influxdb_ping))
1129            .route("/health", routing::get(influxdb_health))
1130            .with_state(influxdb_handler)
1131    }
1132
1133    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1134        Router::new()
1135            .route("/api/put", routing::post(opentsdb::put))
1136            .with_state(opentsdb_handler)
1137    }
1138
1139    fn route_otlp<S>(
1140        otlp_handler: OpenTelemetryProtocolHandlerRef,
1141        with_metric_engine: bool,
1142    ) -> Router<S> {
1143        Router::new()
1144            .route("/v1/metrics", routing::post(otlp::metrics))
1145            .route("/v1/traces", routing::post(otlp::traces))
1146            .route("/v1/logs", routing::post(otlp::logs))
1147            .layer(
1148                ServiceBuilder::new()
1149                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1150            )
1151            .with_state(OtlpState {
1152                with_metric_engine,
1153                handler: otlp_handler,
1154            })
1155    }
1156
1157    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1158        Router::new()
1159            .route("/config", routing::get(handler::config))
1160            .with_state(state)
1161    }
1162
1163    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1164        Router::new()
1165            .route("/api/services", routing::get(jaeger::handle_get_services))
1166            .route(
1167                "/api/services/{service_name}/operations",
1168                routing::get(jaeger::handle_get_operations_by_service),
1169            )
1170            .route(
1171                "/api/operations",
1172                routing::get(jaeger::handle_get_operations),
1173            )
1174            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1175            .route(
1176                "/api/traces/{trace_id}",
1177                routing::get(jaeger::handle_get_trace),
1178            )
1179            .with_state(handler)
1180    }
1181}
1182
1183pub const HTTP_SERVER: &str = "HTTP_SERVER";
1184
1185#[async_trait]
1186impl Server for HttpServer {
1187    async fn shutdown(&self) -> Result<()> {
1188        let mut shutdown_tx = self.shutdown_tx.lock().await;
1189        if let Some(tx) = shutdown_tx.take()
1190            && tx.send(()).is_err()
1191        {
1192            info!("Receiver dropped, the HTTP server has already exited");
1193        }
1194        info!("Shutdown HTTP server");
1195
1196        Ok(())
1197    }
1198
1199    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1200        let (tx, rx) = oneshot::channel();
1201        let serve = {
1202            let mut shutdown_tx = self.shutdown_tx.lock().await;
1203            ensure!(
1204                shutdown_tx.is_none(),
1205                AlreadyStartedSnafu { server: "HTTP" }
1206            );
1207
1208            let mut app = self.make_app();
1209            if let Some(configurator) = self.plugins.get::<HttpConfiguratorRef<()>>() {
1210                app = configurator
1211                    .configure_http(app, ())
1212                    .await
1213                    .context(OtherSnafu)?;
1214            }
1215            let app = self.build(app)?;
1216            let listener = tokio::net::TcpListener::bind(listening)
1217                .await
1218                .context(AddressBindSnafu { addr: listening })?
1219                .tap_io(|tcp_stream| {
1220                    if let Err(e) = tcp_stream.set_nodelay(true) {
1221                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1222                    }
1223                });
1224            let serve = axum::serve(listener, app.into_make_service());
1225
1226            // FIXME(yingwen): Support keepalive.
1227            // See:
1228            // - https://github.com/tokio-rs/axum/discussions/2939
1229            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1230            // let server = axum::Server::try_bind(&listening)
1231            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1232            //     .tcp_nodelay(true)
1233            //     // Enable TCP keepalive to close the dangling established connections.
1234            //     // It's configured to let the keepalive probes first send after the connection sits
1235            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1236            //     // So the connection will be closed after roughly 1 hour.
1237            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1238            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1239            //     .tcp_keepalive_retries(Some(6))
1240            //     .serve(app.into_make_service());
1241
1242            *shutdown_tx = Some(tx);
1243
1244            serve
1245        };
1246        let listening = serve.local_addr().context(InternalIoSnafu)?;
1247        info!("HTTP server is bound to {}", listening);
1248
1249        common_runtime::spawn_global(async move {
1250            if let Err(e) = serve
1251                .with_graceful_shutdown(rx.map(drop))
1252                .await
1253                .context(InternalIoSnafu)
1254            {
1255                error!(e; "Failed to shutdown http server");
1256            }
1257        });
1258
1259        self.bind_addr = Some(listening);
1260        Ok(())
1261    }
1262
1263    fn name(&self) -> &str {
1264        HTTP_SERVER
1265    }
1266
1267    fn bind_addr(&self) -> Option<SocketAddr> {
1268        self.bind_addr
1269    }
1270}
1271
1272#[cfg(test)]
1273mod test {
1274    use std::future::pending;
1275    use std::io::Cursor;
1276    use std::sync::Arc;
1277
1278    use arrow_ipc::reader::FileReader;
1279    use arrow_schema::DataType;
1280    use axum::handler::Handler;
1281    use axum::http::StatusCode;
1282    use axum::routing::get;
1283    use common_query::Output;
1284    use common_recordbatch::RecordBatches;
1285    use datafusion_expr::LogicalPlan;
1286    use datatypes::prelude::*;
1287    use datatypes::schema::{ColumnSchema, Schema};
1288    use datatypes::vectors::{StringVector, UInt32Vector};
1289    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1290    use query::parser::PromQuery;
1291    use query::query_engine::DescribeResult;
1292    use session::context::QueryContextRef;
1293    use sql::statements::statement::Statement;
1294    use tokio::sync::mpsc;
1295    use tokio::time::Instant;
1296
1297    use super::*;
1298    use crate::error::Error;
1299    use crate::http::test_helpers::TestClient;
1300    use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
1301
1302    struct DummyInstance {
1303        _tx: mpsc::Sender<(String, Vec<u8>)>,
1304    }
1305
1306    #[async_trait]
1307    impl SqlQueryHandler for DummyInstance {
1308        type Error = Error;
1309
1310        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1311            unimplemented!()
1312        }
1313
1314        async fn do_promql_query(
1315            &self,
1316            _: &PromQuery,
1317            _: QueryContextRef,
1318        ) -> Vec<std::result::Result<Output, Self::Error>> {
1319            unimplemented!()
1320        }
1321
1322        async fn do_exec_plan(
1323            &self,
1324            _stmt: Option<Statement>,
1325            _plan: LogicalPlan,
1326            _query_ctx: QueryContextRef,
1327        ) -> std::result::Result<Output, Self::Error> {
1328            unimplemented!()
1329        }
1330
1331        async fn do_describe(
1332            &self,
1333            _stmt: sql::statements::statement::Statement,
1334            _query_ctx: QueryContextRef,
1335        ) -> Result<Option<DescribeResult>> {
1336            unimplemented!()
1337        }
1338
1339        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1340            Ok(true)
1341        }
1342    }
1343
1344    fn timeout() -> DynamicTimeoutLayer {
1345        DynamicTimeoutLayer::new(Duration::from_millis(10))
1346    }
1347
1348    async fn forever() {
1349        pending().await
1350    }
1351
1352    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1353        make_test_app_custom(tx, HttpOptions::default())
1354    }
1355
1356    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1357        let instance = Arc::new(DummyInstance { _tx: tx });
1358        let sql_instance = ServerSqlQueryHandlerAdapter::arc(instance.clone());
1359        let server = HttpServerBuilder::new(options)
1360            .with_sql_handler(sql_instance)
1361            .build();
1362        server.build(server.make_app()).unwrap().route(
1363            "/test/timeout",
1364            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1365        )
1366    }
1367
1368    #[tokio::test]
1369    pub async fn test_cors() {
1370        // cors is on by default
1371        let (tx, _rx) = mpsc::channel(100);
1372        let app = make_test_app(tx);
1373        let client = TestClient::new(app).await;
1374
1375        let res = client.get("/health").send().await;
1376
1377        assert_eq!(res.status(), StatusCode::OK);
1378        assert_eq!(
1379            res.headers()
1380                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1381                .expect("expect cors header origin"),
1382            "*"
1383        );
1384
1385        let res = client.get("/v1/health").send().await;
1386
1387        assert_eq!(res.status(), StatusCode::OK);
1388        assert_eq!(
1389            res.headers()
1390                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1391                .expect("expect cors header origin"),
1392            "*"
1393        );
1394
1395        let res = client
1396            .options("/health")
1397            .header("Access-Control-Request-Headers", "x-greptime-auth")
1398            .header("Access-Control-Request-Method", "DELETE")
1399            .header("Origin", "https://example.com")
1400            .send()
1401            .await;
1402        assert_eq!(res.status(), StatusCode::OK);
1403        assert_eq!(
1404            res.headers()
1405                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1406                .expect("expect cors header origin"),
1407            "*"
1408        );
1409        assert_eq!(
1410            res.headers()
1411                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1412                .expect("expect cors header headers"),
1413            "*"
1414        );
1415        assert_eq!(
1416            res.headers()
1417                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1418                .expect("expect cors header methods"),
1419            "GET,POST,PUT,DELETE,HEAD"
1420        );
1421    }
1422
1423    #[tokio::test]
1424    pub async fn test_cors_custom_origins() {
1425        // cors is on by default
1426        let (tx, _rx) = mpsc::channel(100);
1427        let origin = "https://example.com";
1428
1429        let options = HttpOptions {
1430            cors_allowed_origins: vec![origin.to_string()],
1431            ..Default::default()
1432        };
1433
1434        let app = make_test_app_custom(tx, options);
1435        let client = TestClient::new(app).await;
1436
1437        let res = client.get("/health").header("Origin", origin).send().await;
1438
1439        assert_eq!(res.status(), StatusCode::OK);
1440        assert_eq!(
1441            res.headers()
1442                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1443                .expect("expect cors header origin"),
1444            origin
1445        );
1446
1447        let res = client
1448            .get("/health")
1449            .header("Origin", "https://notallowed.com")
1450            .send()
1451            .await;
1452
1453        assert_eq!(res.status(), StatusCode::OK);
1454        assert!(
1455            !res.headers()
1456                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1457        );
1458    }
1459
1460    #[tokio::test]
1461    pub async fn test_cors_disabled() {
1462        // cors is on by default
1463        let (tx, _rx) = mpsc::channel(100);
1464
1465        let options = HttpOptions {
1466            enable_cors: false,
1467            ..Default::default()
1468        };
1469
1470        let app = make_test_app_custom(tx, options);
1471        let client = TestClient::new(app).await;
1472
1473        let res = client.get("/health").send().await;
1474
1475        assert_eq!(res.status(), StatusCode::OK);
1476        assert!(
1477            !res.headers()
1478                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1479        );
1480    }
1481
1482    #[test]
1483    fn test_http_options_default() {
1484        let default = HttpOptions::default();
1485        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1486        assert_eq!(Duration::from_secs(0), default.timeout)
1487    }
1488
1489    #[tokio::test]
1490    async fn test_http_server_request_timeout() {
1491        common_telemetry::init_default_ut_logging();
1492
1493        let (tx, _rx) = mpsc::channel(100);
1494        let app = make_test_app(tx);
1495        let client = TestClient::new(app).await;
1496        let res = client.get("/test/timeout").send().await;
1497        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1498
1499        let now = Instant::now();
1500        let res = client
1501            .get("/test/timeout")
1502            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1503            .send()
1504            .await;
1505        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1506        let elapsed = now.elapsed();
1507        assert!(elapsed > Duration::from_millis(15));
1508
1509        tokio::time::timeout(
1510            Duration::from_millis(15),
1511            client
1512                .get("/test/timeout")
1513                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1514                .send(),
1515        )
1516        .await
1517        .unwrap_err();
1518
1519        tokio::time::timeout(
1520            Duration::from_millis(15),
1521            client
1522                .get("/test/timeout")
1523                .header(
1524                    GREPTIME_DB_HEADER_TIMEOUT,
1525                    humantime::format_duration(Duration::default()).to_string(),
1526                )
1527                .send(),
1528        )
1529        .await
1530        .unwrap_err();
1531    }
1532
1533    #[tokio::test]
1534    async fn test_schema_for_empty_response() {
1535        let column_schemas = vec![
1536            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1537            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1538        ];
1539        let schema = Arc::new(Schema::new(column_schemas));
1540
1541        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1542        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1543
1544        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1545        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1546            let json_output = &json_resp.output[0];
1547            if let GreptimeQueryOutput::Records(r) = json_output {
1548                assert_eq!(r.num_rows(), 0);
1549                assert_eq!(r.num_cols(), 2);
1550                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1551                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1552            } else {
1553                panic!("invalid output type");
1554            }
1555        } else {
1556            panic!("invalid format")
1557        }
1558    }
1559
1560    #[tokio::test]
1561    async fn test_recordbatches_conversion() {
1562        let column_schemas = vec![
1563            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1564            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1565        ];
1566        let schema = Arc::new(Schema::new(column_schemas));
1567        let columns: Vec<VectorRef> = vec![
1568            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1569            Arc::new(StringVector::from(vec![
1570                None,
1571                Some("hello"),
1572                Some("greptime"),
1573                None,
1574            ])),
1575        ];
1576        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1577
1578        for format in [
1579            ResponseFormat::GreptimedbV1,
1580            ResponseFormat::InfluxdbV1,
1581            ResponseFormat::Csv(true, true),
1582            ResponseFormat::Table,
1583            ResponseFormat::Arrow,
1584            ResponseFormat::Json,
1585            ResponseFormat::Null,
1586        ] {
1587            let recordbatches =
1588                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1589            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1590            let json_resp = match format {
1591                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1592                ResponseFormat::Csv(with_names, with_types) => {
1593                    CsvResponse::from_output(outputs, with_names, with_types).await
1594                }
1595                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1596                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1597                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1598                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1599                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1600            };
1601
1602            match json_resp {
1603                HttpResponse::GreptimedbV1(resp) => {
1604                    let json_output = &resp.output[0];
1605                    if let GreptimeQueryOutput::Records(r) = json_output {
1606                        assert_eq!(r.num_rows(), 4);
1607                        assert_eq!(r.num_cols(), 2);
1608                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1609                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1610                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1611                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1612                    } else {
1613                        panic!("invalid output type");
1614                    }
1615                }
1616                HttpResponse::InfluxdbV1(resp) => {
1617                    let json_output = &resp.results()[0];
1618                    assert_eq!(json_output.num_rows(), 4);
1619                    assert_eq!(json_output.num_cols(), 2);
1620                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1621                    assert_eq!(
1622                        json_output.series[0].values[0][0],
1623                        serde_json::Value::from(1)
1624                    );
1625                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1626                }
1627                HttpResponse::Csv(resp) => {
1628                    let output = &resp.output()[0];
1629                    if let GreptimeQueryOutput::Records(r) = output {
1630                        assert_eq!(r.num_rows(), 4);
1631                        assert_eq!(r.num_cols(), 2);
1632                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1633                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1634                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1635                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1636                    } else {
1637                        panic!("invalid output type");
1638                    }
1639                }
1640
1641                HttpResponse::Table(resp) => {
1642                    let output = &resp.output()[0];
1643                    if let GreptimeQueryOutput::Records(r) = output {
1644                        assert_eq!(r.num_rows(), 4);
1645                        assert_eq!(r.num_cols(), 2);
1646                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1647                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1648                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1649                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1650                    } else {
1651                        panic!("invalid output type");
1652                    }
1653                }
1654
1655                HttpResponse::Arrow(resp) => {
1656                    let output = resp.data;
1657                    let mut reader =
1658                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
1659                    let schema = reader.schema();
1660                    assert_eq!(schema.fields[0].name(), "numbers");
1661                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1662                    assert_eq!(schema.fields[1].name(), "strings");
1663                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1664
1665                    let rb = reader.next().unwrap().expect("read record batch failed");
1666                    assert_eq!(rb.num_columns(), 2);
1667                    assert_eq!(rb.num_rows(), 4);
1668                }
1669
1670                HttpResponse::Json(resp) => {
1671                    let output = &resp.output()[0];
1672                    if let GreptimeQueryOutput::Records(r) = output {
1673                        assert_eq!(r.num_rows(), 4);
1674                        assert_eq!(r.num_cols(), 2);
1675                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1676                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1677                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1678                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1679                    } else {
1680                        panic!("invalid output type");
1681                    }
1682                }
1683
1684                HttpResponse::Null(resp) => {
1685                    assert_eq!(resp.rows(), 4);
1686                }
1687
1688                HttpResponse::Error(err) => unreachable!("{err:?}"),
1689            }
1690        }
1691    }
1692
1693    #[test]
1694    fn test_response_format_misc() {
1695        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
1696        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
1697        assert_eq!(
1698            ResponseFormat::parse("csv"),
1699            Some(ResponseFormat::Csv(false, false))
1700        );
1701        assert_eq!(
1702            ResponseFormat::parse("csvwithnames"),
1703            Some(ResponseFormat::Csv(true, false))
1704        );
1705        assert_eq!(
1706            ResponseFormat::parse("csvwithnamesandtypes"),
1707            Some(ResponseFormat::Csv(true, true))
1708        );
1709        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
1710        assert_eq!(
1711            ResponseFormat::parse("greptimedb_v1"),
1712            Some(ResponseFormat::GreptimedbV1)
1713        );
1714        assert_eq!(
1715            ResponseFormat::parse("influxdb_v1"),
1716            Some(ResponseFormat::InfluxdbV1)
1717        );
1718        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
1719        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));
1720
1721        // invalid formats
1722        assert_eq!(ResponseFormat::parse("invalid"), None);
1723        assert_eq!(ResponseFormat::parse(""), None);
1724        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive
1725
1726        // as str
1727        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
1728        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
1729        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
1730        assert_eq!(ResponseFormat::Table.as_str(), "table");
1731        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
1732        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
1733        assert_eq!(ResponseFormat::Json.as_str(), "json");
1734        assert_eq!(ResponseFormat::Null.as_str(), "null");
1735        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
1736    }
1737}