servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::Display;
17use std::net::SocketAddr;
18use std::sync::Mutex as StdMutex;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use auth::UserProviderRef;
23use axum::extract::DefaultBodyLimit;
24use axum::http::StatusCode as HttpStatusCode;
25use axum::response::{IntoResponse, Response};
26use axum::serve::ListenerExt;
27use axum::{middleware, routing, Router};
28use common_base::readable_size::ReadableSize;
29use common_base::Plugins;
30use common_recordbatch::RecordBatch;
31use common_telemetry::{debug, error, info};
32use common_time::timestamp::TimeUnit;
33use common_time::Timestamp;
34use datatypes::data_type::DataType;
35use datatypes::schema::SchemaRef;
36use datatypes::value::transform_value_ref_to_json_value;
37use event::{LogState, LogValidatorRef};
38use futures::FutureExt;
39use http::{HeaderValue, Method};
40use prost::DecodeError;
41use serde::{Deserialize, Serialize};
42use serde_json::Value;
43use snafu::{ensure, ResultExt};
44use tokio::sync::oneshot::{self, Sender};
45use tokio::sync::Mutex;
46use tower::ServiceBuilder;
47use tower_http::compression::CompressionLayer;
48use tower_http::cors::{AllowOrigin, Any, CorsLayer};
49use tower_http::decompression::RequestDecompressionLayer;
50use tower_http::trace::TraceLayer;
51
52use self::authorize::AuthState;
53use self::result::table_result::TableResponse;
54use crate::configurator::ConfiguratorRef;
55use crate::elasticsearch;
56use crate::error::{
57    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu, Result,
58    ToJsonSnafu,
59};
60use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
61use crate::http::prom_store::PromStoreState;
62use crate::http::prometheus::{
63    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
64    range_query, series_query,
65};
66use crate::http::result::arrow_result::ArrowResponse;
67use crate::http::result::csv_result::CsvResponse;
68use crate::http::result::error_result::ErrorResponse;
69use crate::http::result::greptime_result_v1::GreptimedbV1Response;
70use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
71use crate::http::result::json_result::JsonResponse;
72use crate::interceptor::LogIngestInterceptorRef;
73use crate::metrics::http_metrics_layer;
74use crate::metrics_handler::MetricsHandler;
75use crate::prometheus_handler::PrometheusHandlerRef;
76use crate::query_handler::sql::ServerSqlQueryHandlerRef;
77use crate::query_handler::{
78    InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
79    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
80    PromStoreProtocolHandlerRef,
81};
82use crate::server::Server;
83
84pub mod authorize;
85#[cfg(feature = "dashboard")]
86mod dashboard;
87pub mod dyn_log;
88pub mod event;
89mod extractor;
90pub mod handler;
91pub mod header;
92pub mod influxdb;
93pub mod jaeger;
94pub mod logs;
95pub mod loki;
96pub mod mem_prof;
97pub mod opentsdb;
98pub mod otlp;
99pub mod pprof;
100pub mod prom_store;
101pub mod prometheus;
102pub mod result;
103mod timeout;
104
105pub(crate) use timeout::DynamicTimeoutLayer;
106
107mod hints;
108mod read_preference;
109#[cfg(any(test, feature = "testing"))]
110pub mod test_helpers;
111
/// Version segment that prefixes every versioned HTTP API route (e.g. `/v1/sql`).
pub const HTTP_API_VERSION: &str = "v1";
/// `HTTP_API_VERSION` surrounded by slashes; must be kept in sync with it by hand.
pub const HTTP_API_PREFIX: &str = "/v1/";
/// Default http body limit (64M).
const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);

/// Name of the custom authorization header.
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";

// Paths treated as public endpoints; presumably exempt from HTTP authentication — confirm in
// the `authorize` module.
// TODO(fys): This is a temporary workaround, it will be improved later
pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
122
/// HTTP server: the routes assembled by [HttpServerBuilder] plus the configuration and
/// runtime state needed to serve them.
#[derive(Default)]
pub struct HttpServer {
    /// Registered routes; [HttpServer::make_app] works on a clone taken under this lock.
    router: StdMutex<Router>,
    /// Shutdown signal sender. [HttpServerBuilder::build] initializes it to `None`; presumably
    /// populated when the server starts — confirm in the `Server` impl.
    shutdown_tx: Mutex<Option<Sender<()>>>,
    /// Optional user provider handed to the HTTP auth middleware in [HttpServer::build].
    user_provider: Option<UserProviderRef>,

    // plugins
    plugins: Plugins,

    // server configs
    options: HttpOptions,
    /// Bound socket address; `None` when created by [HttpServerBuilder::build].
    bind_addr: Option<SocketAddr>,
}
136
/// Configuration of the HTTP server.
///
/// Fields missing during deserialization fall back to [HttpOptions::default]
/// (`#[serde(default)]`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct HttpOptions {
    /// Address the server listens on, e.g. `127.0.0.1:4000`.
    pub addr: String,

    /// Request timeout, (de)serialized in humantime format (e.g. `30s`). A zero duration is
    /// treated as "disabled" by `HttpServer::build`.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,

    /// Turns off the `/dashboard` routes. Never (de)serialized (`#[serde(skip)]`).
    #[serde(skip)]
    pub disable_dashboard: bool,

    /// Maximum request body size. A value of zero is treated as "disabled" by
    /// `HttpServer::build`.
    pub body_limit: ReadableSize,

    /// Validation mode while decoding Prometheus remote write requests.
    pub prom_validation_mode: PromValidationMode,

    /// Origins allowed by CORS. An empty list allows any origin; every entry must be a valid
    /// HTTP header value, otherwise `HttpServer::build` returns an error.
    pub cors_allowed_origins: Vec<String>,

    /// Whether the CORS layer is installed at all.
    pub enable_cors: bool,
}
157
/// UTF-8 handling for strings decoded from Prometheus remote write payloads.
///
/// Serialized in snake_case: `strict`, `lossy`, `unchecked`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PromValidationMode {
    /// Force UTF8 validation
    Strict,
    /// Allow lossy UTF8 strings
    Lossy,
    /// Do not validate UTF8 strings.
    Unchecked,
}
168
169impl PromValidationMode {
170    /// Decodes provided bytes to [String] with optional UTF-8 validation.
171    pub fn decode_string(&self, bytes: &[u8]) -> std::result::Result<String, DecodeError> {
172        let result = match self {
173            PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
174                Ok(s) => s,
175                Err(e) => {
176                    debug!("Invalid UTF-8 string value: {:?}, error: {:?}", bytes, e);
177                    return Err(DecodeError::new("invalid utf-8"));
178                }
179            },
180            PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
181            PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
182        };
183        Ok(result)
184    }
185}
186
187impl Default for HttpOptions {
188    fn default() -> Self {
189        Self {
190            addr: "127.0.0.1:4000".to_string(),
191            timeout: Duration::from_secs(0),
192            disable_dashboard: false,
193            body_limit: DEFAULT_BODY_LIMIT,
194            cors_allowed_origins: Vec::new(),
195            enable_cors: true,
196            prom_validation_mode: PromValidationMode::Strict,
197        }
198    }
199}
200
/// Name and data-type name of one result column, as serialized into HTTP responses.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct ColumnSchema {
    name: String,
    /// Data type name (filled from `DataType::name` when built from a schema).
    data_type: String,
}
206
207impl ColumnSchema {
208    pub fn new(name: String, data_type: String) -> ColumnSchema {
209        ColumnSchema { name, data_type }
210    }
211}
212
/// Column descriptions of a query result, in column order.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct OutputSchema {
    column_schemas: Vec<ColumnSchema>,
}
217
218impl OutputSchema {
219    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
220        OutputSchema {
221            column_schemas: columns,
222        }
223    }
224}
225
226impl From<SchemaRef> for OutputSchema {
227    fn from(schema: SchemaRef) -> OutputSchema {
228        OutputSchema {
229            column_schemas: schema
230                .column_schemas()
231                .iter()
232                .map(|cs| ColumnSchema {
233                    name: cs.name.clone(),
234                    data_type: cs.data_type.name(),
235                })
236                .collect(),
237        }
238    }
239}
240
/// Row-major query result returned over HTTP.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct HttpRecordsOutput {
    schema: OutputSchema,
    /// One entry per row; each row holds one JSON value per column, in schema order.
    rows: Vec<Vec<Value>>,
    // total_rows equals rows.len() in most cases; when the result shown by the Dashboard is
    // truncated, this field is what tells the client how many rows there were.
    // Deserializes to 0 when absent (`#[serde(default)]`).
    #[serde(default)]
    total_rows: usize,

    // plan level execution metrics; omitted from the output when empty
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    #[serde(default)]
    metrics: HashMap<String, Value>,
}
255
256impl HttpRecordsOutput {
257    pub fn num_rows(&self) -> usize {
258        self.rows.len()
259    }
260
261    pub fn num_cols(&self) -> usize {
262        self.schema.column_schemas.len()
263    }
264
265    pub fn schema(&self) -> &OutputSchema {
266        &self.schema
267    }
268
269    pub fn rows(&self) -> &Vec<Vec<Value>> {
270        &self.rows
271    }
272}
273
274impl HttpRecordsOutput {
275    pub fn try_new(
276        schema: SchemaRef,
277        recordbatches: Vec<RecordBatch>,
278    ) -> std::result::Result<HttpRecordsOutput, Error> {
279        if recordbatches.is_empty() {
280            Ok(HttpRecordsOutput {
281                schema: OutputSchema::from(schema),
282                rows: vec![],
283                total_rows: 0,
284                metrics: Default::default(),
285            })
286        } else {
287            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
288            let mut rows = Vec::with_capacity(num_rows);
289            let schemas = schema.column_schemas();
290            let num_cols = schema.column_schemas().len();
291            rows.resize_with(num_rows, || Vec::with_capacity(num_cols));
292
293            let mut finished_row_cursor = 0;
294            for recordbatch in recordbatches {
295                for (col_idx, col) in recordbatch.columns().iter().enumerate() {
296                    // safety here: schemas length is equal to the number of columns in the recordbatch
297                    let schema = &schemas[col_idx];
298                    for row_idx in 0..recordbatch.num_rows() {
299                        let value = transform_value_ref_to_json_value(col.get_ref(row_idx), schema)
300                            .context(ToJsonSnafu)?;
301                        rows[row_idx + finished_row_cursor].push(value);
302                    }
303                }
304                finished_row_cursor += recordbatch.num_rows();
305            }
306
307            Ok(HttpRecordsOutput {
308                schema: OutputSchema::from(schema),
309                total_rows: rows.len(),
310                rows,
311                metrics: Default::default(),
312            })
313        }
314    }
315}
316
/// Output of one statement in an HTTP query response.
///
/// Variant names are serialized in lowercase (`affectedrows`, `records`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GreptimeQueryOutput {
    /// Affected-row count of the statement.
    AffectedRows(usize),
    /// Rows returned by the statement.
    Records(HttpRecordsOutput),
}
323
/// Output formats in which the results of SQL queries can be presented.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    /// CSV output; the flags are `(with_names, with_types)`.
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
}

impl ResponseFormat {
    /// Parses a format name. Matching is exact and case-sensitive.
    ///
    /// The three CSV spellings (`csv`, `csvwithnames`, `csvwithnamesandtypes`) select the
    /// corresponding header flags; unknown names yield `None`.
    pub fn parse(s: &str) -> Option<Self> {
        const KNOWN_FORMATS: [(&str, ResponseFormat); 8] = [
            ("arrow", ResponseFormat::Arrow),
            ("csv", ResponseFormat::Csv(false, false)),
            ("csvwithnames", ResponseFormat::Csv(true, false)),
            ("csvwithnamesandtypes", ResponseFormat::Csv(true, true)),
            ("table", ResponseFormat::Table),
            ("greptimedb_v1", ResponseFormat::GreptimedbV1),
            ("influxdb_v1", ResponseFormat::InfluxdbV1),
            ("json", ResponseFormat::Json),
        ];
        KNOWN_FORMATS
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, format)| format)
    }

    /// Returns the short name of the format.
    ///
    /// All CSV variants report `"csv"`, so `parse(f.as_str())` does not round-trip the CSV
    /// header flags.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
        }
    }
}
363
/// Time unit that timestamps are converted to (see [Epoch::convert_timestamp]), selected via
/// the InfluxDB `epoch` query parameter (see [Epoch::parse]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Epoch {
    Nanosecond,
    Microsecond,
    Millisecond,
    Second,
}
371
372impl Epoch {
373    pub fn parse(s: &str) -> Option<Epoch> {
374        // Both u and µ indicate microseconds.
375        // epoch = [ns,u,µ,ms,s],
376        // For details, see the Influxdb documents.
377        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
378        match s {
379            "ns" => Some(Epoch::Nanosecond),
380            "u" | "µ" => Some(Epoch::Microsecond),
381            "ms" => Some(Epoch::Millisecond),
382            "s" => Some(Epoch::Second),
383            _ => None, // just returns None for other cases
384        }
385    }
386
387    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
388        match self {
389            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
390            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
391            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
392            Epoch::Second => ts.convert_to(TimeUnit::Second),
393        }
394    }
395}
396
397impl Display for Epoch {
398    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
399        match self {
400            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
401            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
402            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
403            Epoch::Second => write!(f, "Epoch::Second"),
404        }
405    }
406}
407
/// An HTTP query response in one of the supported output formats (see [ResponseFormat]),
/// or an error response. Rendered via its `IntoResponse` implementation.
#[derive(Serialize, Deserialize, Debug)]
pub enum HttpResponse {
    Arrow(ArrowResponse),
    Csv(CsvResponse),
    Table(TableResponse),
    Error(ErrorResponse),
    GreptimedbV1(GreptimedbV1Response),
    InfluxdbV1(InfluxdbV1Response),
    Json(JsonResponse),
}
418
419impl HttpResponse {
420    pub fn with_execution_time(self, execution_time: u64) -> Self {
421        match self {
422            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
423            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
424            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
425            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
426            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
427            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
428            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
429        }
430    }
431
432    pub fn with_limit(self, limit: usize) -> Self {
433        match self {
434            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
435            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
436            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
437            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
438            _ => self,
439        }
440    }
441}
442
443pub fn process_with_limit(
444    mut outputs: Vec<GreptimeQueryOutput>,
445    limit: usize,
446) -> Vec<GreptimeQueryOutput> {
447    outputs
448        .drain(..)
449        .map(|data| match data {
450            GreptimeQueryOutput::Records(mut records) => {
451                if records.rows.len() > limit {
452                    records.rows.truncate(limit);
453                    records.total_rows = limit;
454                }
455                GreptimeQueryOutput::Records(records)
456            }
457            _ => data,
458        })
459        .collect()
460}
461
462impl IntoResponse for HttpResponse {
463    fn into_response(self) -> Response {
464        match self {
465            HttpResponse::Arrow(resp) => resp.into_response(),
466            HttpResponse::Csv(resp) => resp.into_response(),
467            HttpResponse::Table(resp) => resp.into_response(),
468            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
469            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
470            HttpResponse::Json(resp) => resp.into_response(),
471            HttpResponse::Error(resp) => resp.into_response(),
472        }
473    }
474}
475
476impl From<ArrowResponse> for HttpResponse {
477    fn from(value: ArrowResponse) -> Self {
478        HttpResponse::Arrow(value)
479    }
480}
481
482impl From<CsvResponse> for HttpResponse {
483    fn from(value: CsvResponse) -> Self {
484        HttpResponse::Csv(value)
485    }
486}
487
488impl From<TableResponse> for HttpResponse {
489    fn from(value: TableResponse) -> Self {
490        HttpResponse::Table(value)
491    }
492}
493
494impl From<ErrorResponse> for HttpResponse {
495    fn from(value: ErrorResponse) -> Self {
496        HttpResponse::Error(value)
497    }
498}
499
500impl From<GreptimedbV1Response> for HttpResponse {
501    fn from(value: GreptimedbV1Response) -> Self {
502        HttpResponse::GreptimedbV1(value)
503    }
504}
505
506impl From<InfluxdbV1Response> for HttpResponse {
507    fn from(value: InfluxdbV1Response) -> Self {
508        HttpResponse::InfluxdbV1(value)
509    }
510}
511
512impl From<JsonResponse> for HttpResponse {
513    fn from(value: JsonResponse) -> Self {
514        HttpResponse::Json(value)
515    }
516}
517
/// Axum state for the SQL API routes (passed to `HttpServer::route_sql`).
#[derive(Clone)]
pub struct ApiState {
    /// Handler used to serve SQL query requests.
    pub sql_handler: ServerSqlQueryHandlerRef,
}
522
/// Axum state for the config route (passed to `HttpServer::route_config`).
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    /// Server options already rendered as a string; the format is chosen by the caller of
    /// [HttpServerBuilder::with_greptime_config_options].
    pub greptime_config_options: String,
}
527
/// Builder for [HttpServer]. Each `with_*` method registers routes or settings, and
/// [HttpServerBuilder::build] produces the server.
#[derive(Default)]
pub struct HttpServerBuilder {
    options: HttpOptions,
    plugins: Plugins,
    user_provider: Option<UserProviderRef>,
    /// Routes accumulated by the `with_*` methods.
    router: Router,
}
535
536impl HttpServerBuilder {
537    pub fn new(options: HttpOptions) -> Self {
538        Self {
539            options,
540            plugins: Plugins::default(),
541            user_provider: None,
542            router: Router::new(),
543        }
544    }
545
546    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
547        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
548
549        Self {
550            router: self
551                .router
552                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
553            ..self
554        }
555    }
556
557    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
558        let logs_router = HttpServer::route_logs(logs_handler);
559
560        Self {
561            router: self
562                .router
563                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
564            ..self
565        }
566    }
567
568    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
569        Self {
570            router: self.router.nest(
571                &format!("/{HTTP_API_VERSION}/opentsdb"),
572                HttpServer::route_opentsdb(handler),
573            ),
574            ..self
575        }
576    }
577
578    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
579        Self {
580            router: self.router.nest(
581                &format!("/{HTTP_API_VERSION}/influxdb"),
582                HttpServer::route_influxdb(handler),
583            ),
584            ..self
585        }
586    }
587
588    pub fn with_prom_handler(
589        self,
590        handler: PromStoreProtocolHandlerRef,
591        pipeline_handler: Option<PipelineHandlerRef>,
592        prom_store_with_metric_engine: bool,
593        prom_validation_mode: PromValidationMode,
594    ) -> Self {
595        let state = PromStoreState {
596            prom_store_handler: handler,
597            pipeline_handler,
598            prom_store_with_metric_engine,
599            prom_validation_mode,
600        };
601
602        Self {
603            router: self.router.nest(
604                &format!("/{HTTP_API_VERSION}/prometheus"),
605                HttpServer::route_prom(state),
606            ),
607            ..self
608        }
609    }
610
611    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
612        Self {
613            router: self.router.nest(
614                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
615                HttpServer::route_prometheus(handler),
616            ),
617            ..self
618        }
619    }
620
621    pub fn with_otlp_handler(self, handler: OpenTelemetryProtocolHandlerRef) -> Self {
622        Self {
623            router: self.router.nest(
624                &format!("/{HTTP_API_VERSION}/otlp"),
625                HttpServer::route_otlp(handler),
626            ),
627            ..self
628        }
629    }
630
631    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
632        Self {
633            user_provider: Some(user_provider),
634            ..self
635        }
636    }
637
638    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
639        Self {
640            router: self.router.merge(HttpServer::route_metrics(handler)),
641            ..self
642        }
643    }
644
645    pub fn with_log_ingest_handler(
646        self,
647        handler: PipelineHandlerRef,
648        validator: Option<LogValidatorRef>,
649        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
650    ) -> Self {
651        let log_state = LogState {
652            log_handler: handler,
653            log_validator: validator,
654            ingest_interceptor,
655        };
656
657        let router = self.router.nest(
658            &format!("/{HTTP_API_VERSION}"),
659            HttpServer::route_pipelines(log_state.clone()),
660        );
661        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
662        let router = router.nest(
663            &format!("/{HTTP_API_VERSION}/events"),
664            #[allow(deprecated)]
665            HttpServer::route_log_deprecated(log_state.clone()),
666        );
667
668        let router = router.nest(
669            &format!("/{HTTP_API_VERSION}/loki"),
670            HttpServer::route_loki(log_state.clone()),
671        );
672
673        let router = router.nest(
674            &format!("/{HTTP_API_VERSION}/elasticsearch"),
675            HttpServer::route_elasticsearch(log_state.clone()),
676        );
677
678        let router = router.nest(
679            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
680            Router::new()
681                .route("/", routing::get(elasticsearch::handle_get_version))
682                .with_state(log_state),
683        );
684
685        Self { router, ..self }
686    }
687
688    pub fn with_plugins(self, plugins: Plugins) -> Self {
689        Self { plugins, ..self }
690    }
691
692    pub fn with_greptime_config_options(self, opts: String) -> Self {
693        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
694            greptime_config_options: opts,
695        });
696
697        Self {
698            router: self.router.merge(config_router),
699            ..self
700        }
701    }
702
703    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
704        Self {
705            router: self.router.nest(
706                &format!("/{HTTP_API_VERSION}/jaeger"),
707                HttpServer::route_jaeger(handler),
708            ),
709            ..self
710        }
711    }
712
713    pub fn with_extra_router(self, router: Router) -> Self {
714        Self {
715            router: self.router.merge(router),
716            ..self
717        }
718    }
719
720    pub fn build(self) -> HttpServer {
721        HttpServer {
722            options: self.options,
723            user_provider: self.user_provider,
724            shutdown_tx: Mutex::new(None),
725            plugins: self.plugins,
726            router: StdMutex::new(self.router),
727            bind_addr: None,
728        }
729    }
730}
731
732impl HttpServer {
733    /// Gets the router and adds necessary root routes (health, status, dashboard).
734    pub fn make_app(&self) -> Router {
735        let mut router = {
736            let router = self.router.lock().unwrap();
737            router.clone()
738        };
739
740        router = router
741            .route(
742                "/health",
743                routing::get(handler::health).post(handler::health),
744            )
745            .route(
746                &format!("/{HTTP_API_VERSION}/health"),
747                routing::get(handler::health).post(handler::health),
748            )
749            .route(
750                "/ready",
751                routing::get(handler::health).post(handler::health),
752            );
753
754        router = router.route("/status", routing::get(handler::status));
755
756        #[cfg(feature = "dashboard")]
757        {
758            if !self.options.disable_dashboard {
759                info!("Enable dashboard service at '/dashboard'");
760                // redirect /dashboard to /dashboard/
761                router = router.route(
762                    "/dashboard",
763                    routing::get(|uri: axum::http::uri::Uri| async move {
764                        let path = uri.path();
765                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
766
767                        let new_uri = format!("{}/{}", path, query);
768                        axum::response::Redirect::permanent(&new_uri)
769                    }),
770                );
771
772                // "/dashboard" and "/dashboard/" are two different paths in Axum.
773                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
774                // So we explicitly route "/dashboard/" here.
775                router = router
776                    .route(
777                        "/dashboard/",
778                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
779                    )
780                    .route(
781                        "/dashboard/{*x}",
782                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
783                    );
784            }
785        }
786
787        // Add a layer to collect HTTP metrics for axum.
788        router = router.route_layer(middleware::from_fn(http_metrics_layer));
789
790        router
791    }
792
793    /// Attaches middlewares and debug routes to the router.
794    /// Callers should call this method after [HttpServer::make_app()].
795    pub fn build(&self, router: Router) -> Result<Router> {
796        let timeout_layer = if self.options.timeout != Duration::default() {
797            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
798        } else {
799            info!("HTTP server timeout is disabled");
800            None
801        };
802        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
803            Some(
804                ServiceBuilder::new()
805                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
806            )
807        } else {
808            info!("HTTP server body limit is disabled");
809            None
810        };
811        let cors_layer = if self.options.enable_cors {
812            Some(
813                CorsLayer::new()
814                    .allow_methods([
815                        Method::GET,
816                        Method::POST,
817                        Method::PUT,
818                        Method::DELETE,
819                        Method::HEAD,
820                    ])
821                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
822                        AllowOrigin::from(Any)
823                    } else {
824                        AllowOrigin::from(
825                            self.options
826                                .cors_allowed_origins
827                                .iter()
828                                .map(|s| {
829                                    HeaderValue::from_str(s.as_str())
830                                        .context(InvalidHeaderValueSnafu)
831                                })
832                                .collect::<Result<Vec<HeaderValue>>>()?,
833                        )
834                    })
835                    .allow_headers(Any),
836            )
837        } else {
838            info!("HTTP server cross-origin is disabled");
839            None
840        };
841
842        Ok(router
843            // middlewares
844            .layer(
845                ServiceBuilder::new()
846                    // disable on failure tracing. because printing out isn't very helpful,
847                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
848                    .layer(TraceLayer::new_for_http().on_failure(()))
849                    .option_layer(cors_layer)
850                    .option_layer(timeout_layer)
851                    .option_layer(body_limit_layer)
852                    // auth layer
853                    .layer(middleware::from_fn_with_state(
854                        AuthState::new(self.user_provider.clone()),
855                        authorize::check_http_auth,
856                    ))
857                    .layer(middleware::from_fn(hints::extract_hints))
858                    .layer(middleware::from_fn(
859                        read_preference::extract_read_preference,
860                    )),
861            )
862            // Handlers for debug, we don't expect a timeout.
863            .nest(
864                "/debug",
865                Router::new()
866                    // handler for changing log level dynamically
867                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
868                    .nest(
869                        "/prof",
870                        Router::new()
871                            .route("/cpu", routing::post(pprof::pprof_handler))
872                            .route("/mem", routing::post(mem_prof::mem_prof_handler)),
873                    ),
874            ))
875    }
876
877    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
878        Router::new()
879            .route("/metrics", routing::get(handler::metrics))
880            .with_state(metrics_handler)
881    }
882
883    fn route_loki<S>(log_state: LogState) -> Router<S> {
884        Router::new()
885            .route("/api/v1/push", routing::post(loki::loki_ingest))
886            .layer(
887                ServiceBuilder::new()
888                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
889            )
890            .with_state(log_state)
891    }
892
893    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
894        Router::new()
895            // Return fake responsefor HEAD '/' request.
896            .route(
897                "/",
898                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
899            )
900            // Return fake response for Elasticsearch version request.
901            .route("/", routing::get(elasticsearch::handle_get_version))
902            // Return fake response for Elasticsearch license request.
903            .route("/_license", routing::get(elasticsearch::handle_get_license))
904            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
905            .route(
906                "/{index}/_bulk",
907                routing::post(elasticsearch::handle_bulk_api_with_index),
908            )
909            // Return fake response for Elasticsearch ilm request.
910            .route(
911                "/_ilm/policy/{*path}",
912                routing::any((
913                    HttpStatusCode::OK,
914                    elasticsearch::elasticsearch_headers(),
915                    axum::Json(serde_json::json!({})),
916                )),
917            )
918            // Return fake response for Elasticsearch index template request.
919            .route(
920                "/_index_template/{*path}",
921                routing::any((
922                    HttpStatusCode::OK,
923                    elasticsearch::elasticsearch_headers(),
924                    axum::Json(serde_json::json!({})),
925                )),
926            )
927            // Return fake response for Elasticsearch ingest pipeline request.
928            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
929            .route(
930                "/_ingest/{*path}",
931                routing::any((
932                    HttpStatusCode::OK,
933                    elasticsearch::elasticsearch_headers(),
934                    axum::Json(serde_json::json!({})),
935                )),
936            )
937            // Return fake response for Elasticsearch nodes discovery request.
938            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
939            .route(
940                "/_nodes/{*path}",
941                routing::any((
942                    HttpStatusCode::OK,
943                    elasticsearch::elasticsearch_headers(),
944                    axum::Json(serde_json::json!({})),
945                )),
946            )
947            // Return fake response for Logstash APIs requests.
948            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
949            .route(
950                "/logstash/{*path}",
951                routing::any((
952                    HttpStatusCode::OK,
953                    elasticsearch::elasticsearch_headers(),
954                    axum::Json(serde_json::json!({})),
955                )),
956            )
957            .route(
958                "/_logstash/{*path}",
959                routing::any((
960                    HttpStatusCode::OK,
961                    elasticsearch::elasticsearch_headers(),
962                    axum::Json(serde_json::json!({})),
963                )),
964            )
965            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
966            .with_state(log_state)
967    }
968
969    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
970    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
971        Router::new()
972            .route("/logs", routing::post(event::log_ingester))
973            .route(
974                "/pipelines/{pipeline_name}",
975                routing::get(event::query_pipeline),
976            )
977            .route(
978                "/pipelines/{pipeline_name}",
979                routing::post(event::add_pipeline),
980            )
981            .route(
982                "/pipelines/{pipeline_name}",
983                routing::delete(event::delete_pipeline),
984            )
985            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
986            .layer(
987                ServiceBuilder::new()
988                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
989            )
990            .with_state(log_state)
991    }
992
993    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
994        Router::new()
995            .route("/ingest", routing::post(event::log_ingester))
996            .route(
997                "/pipelines/{pipeline_name}",
998                routing::get(event::query_pipeline),
999            )
1000            .route(
1001                "/pipelines/{pipeline_name}",
1002                routing::post(event::add_pipeline),
1003            )
1004            .route(
1005                "/pipelines/{pipeline_name}",
1006                routing::delete(event::delete_pipeline),
1007            )
1008            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1009            .layer(
1010                ServiceBuilder::new()
1011                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1012            )
1013            .with_state(log_state)
1014    }
1015
1016    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1017        Router::new()
1018            .route("/sql", routing::get(handler::sql).post(handler::sql))
1019            .route(
1020                "/sql/parse",
1021                routing::get(handler::sql_parse).post(handler::sql_parse),
1022            )
1023            .route(
1024                "/promql",
1025                routing::get(handler::promql).post(handler::promql),
1026            )
1027            .with_state(api_state)
1028    }
1029
1030    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1031        Router::new()
1032            .route("/logs", routing::get(logs::logs).post(logs::logs))
1033            .with_state(log_handler)
1034    }
1035
1036    /// Route Prometheus [HTTP API].
1037    ///
1038    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1039    fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1040        Router::new()
1041            .route(
1042                "/format_query",
1043                routing::post(format_query).get(format_query),
1044            )
1045            .route("/status/buildinfo", routing::get(build_info_query))
1046            .route("/query", routing::post(instant_query).get(instant_query))
1047            .route("/query_range", routing::post(range_query).get(range_query))
1048            .route("/labels", routing::post(labels_query).get(labels_query))
1049            .route("/series", routing::post(series_query).get(series_query))
1050            .route("/parse_query", routing::post(parse_query).get(parse_query))
1051            .route(
1052                "/label/{label_name}/values",
1053                routing::get(label_values_query),
1054            )
1055            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1056            .with_state(prometheus_handler)
1057    }
1058
1059    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1060    /// called `prom_store`.
1061    ///
1062    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1063    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1064    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1065        Router::new()
1066            .route("/read", routing::post(prom_store::remote_read))
1067            .route("/write", routing::post(prom_store::remote_write))
1068            .with_state(state)
1069    }
1070
1071    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1072        Router::new()
1073            .route("/write", routing::post(influxdb_write_v1))
1074            .route("/api/v2/write", routing::post(influxdb_write_v2))
1075            .layer(
1076                ServiceBuilder::new()
1077                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1078            )
1079            .route("/ping", routing::get(influxdb_ping))
1080            .route("/health", routing::get(influxdb_health))
1081            .with_state(influxdb_handler)
1082    }
1083
1084    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1085        Router::new()
1086            .route("/api/put", routing::post(opentsdb::put))
1087            .with_state(opentsdb_handler)
1088    }
1089
1090    fn route_otlp<S>(otlp_handler: OpenTelemetryProtocolHandlerRef) -> Router<S> {
1091        Router::new()
1092            .route("/v1/metrics", routing::post(otlp::metrics))
1093            .route("/v1/traces", routing::post(otlp::traces))
1094            .route("/v1/logs", routing::post(otlp::logs))
1095            .layer(
1096                ServiceBuilder::new()
1097                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1098            )
1099            .with_state(otlp_handler)
1100    }
1101
1102    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1103        Router::new()
1104            .route("/config", routing::get(handler::config))
1105            .with_state(state)
1106    }
1107
1108    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1109        Router::new()
1110            .route("/api/services", routing::get(jaeger::handle_get_services))
1111            .route(
1112                "/api/services/{service_name}/operations",
1113                routing::get(jaeger::handle_get_operations_by_service),
1114            )
1115            .route(
1116                "/api/operations",
1117                routing::get(jaeger::handle_get_operations),
1118            )
1119            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1120            .route(
1121                "/api/traces/{trace_id}",
1122                routing::get(jaeger::handle_get_trace),
1123            )
1124            .with_state(handler)
1125    }
1126}
1127
1128pub const HTTP_SERVER: &str = "HTTP_SERVER";
1129
1130#[async_trait]
1131impl Server for HttpServer {
1132    async fn shutdown(&self) -> Result<()> {
1133        let mut shutdown_tx = self.shutdown_tx.lock().await;
1134        if let Some(tx) = shutdown_tx.take() {
1135            if tx.send(()).is_err() {
1136                info!("Receiver dropped, the HTTP server has already existed");
1137            }
1138        }
1139        info!("Shutdown HTTP server");
1140
1141        Ok(())
1142    }
1143
1144    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1145        let (tx, rx) = oneshot::channel();
1146        let serve = {
1147            let mut shutdown_tx = self.shutdown_tx.lock().await;
1148            ensure!(
1149                shutdown_tx.is_none(),
1150                AlreadyStartedSnafu { server: "HTTP" }
1151            );
1152
1153            let mut app = self.make_app();
1154            if let Some(configurator) = self.plugins.get::<ConfiguratorRef>() {
1155                app = configurator.config_http(app);
1156            }
1157            let app = self.build(app)?;
1158            let listener = tokio::net::TcpListener::bind(listening)
1159                .await
1160                .context(AddressBindSnafu { addr: listening })?
1161                .tap_io(|tcp_stream| {
1162                    if let Err(e) = tcp_stream.set_nodelay(true) {
1163                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1164                    }
1165                });
1166            let serve = axum::serve(listener, app.into_make_service());
1167
1168            // FIXME(yingwen): Support keepalive.
1169            // See:
1170            // - https://github.com/tokio-rs/axum/discussions/2939
1171            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1172            // let server = axum::Server::try_bind(&listening)
1173            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1174            //     .tcp_nodelay(true)
1175            //     // Enable TCP keepalive to close the dangling established connections.
1176            //     // It's configured to let the keepalive probes first send after the connection sits
1177            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1178            //     // So the connection will be closed after roughly 1 hour.
1179            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1180            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1181            //     .tcp_keepalive_retries(Some(6))
1182            //     .serve(app.into_make_service());
1183
1184            *shutdown_tx = Some(tx);
1185
1186            serve
1187        };
1188        let listening = serve.local_addr().context(InternalIoSnafu)?;
1189        info!("HTTP server is bound to {}", listening);
1190
1191        common_runtime::spawn_global(async move {
1192            if let Err(e) = serve
1193                .with_graceful_shutdown(rx.map(drop))
1194                .await
1195                .context(InternalIoSnafu)
1196            {
1197                error!(e; "Failed to shutdown http server");
1198            }
1199        });
1200
1201        self.bind_addr = Some(listening);
1202        Ok(())
1203    }
1204
1205    fn name(&self) -> &str {
1206        HTTP_SERVER
1207    }
1208
1209    fn bind_addr(&self) -> Option<SocketAddr> {
1210        self.bind_addr
1211    }
1212}
1213
#[cfg(test)]
mod test {
    use std::future::pending;
    use std::io::Cursor;
    use std::sync::Arc;

    use arrow_ipc::reader::FileReader;
    use arrow_schema::DataType;
    use axum::handler::Handler;
    use axum::http::StatusCode;
    use axum::routing::get;
    use common_query::Output;
    use common_recordbatch::RecordBatches;
    use datafusion_expr::LogicalPlan;
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{StringVector, UInt32Vector};
    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
    use query::parser::PromQuery;
    use query::query_engine::DescribeResult;
    use session::context::QueryContextRef;
    use tokio::sync::mpsc;
    use tokio::time::Instant;

    use super::*;
    use crate::error::Error;
    use crate::http::test_helpers::TestClient;
    use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};

    /// Minimal [`SqlQueryHandler`] used only to construct an `HttpServer`. Its query methods
    /// are `unimplemented!()`, so the tests below must not route any query to it.
    struct DummyInstance {
        _tx: mpsc::Sender<(String, Vec<u8>)>,
    }

    #[async_trait]
    impl SqlQueryHandler for DummyInstance {
        type Error = Error;

        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_promql_query(
            &self,
            _: &PromQuery,
            _: QueryContextRef,
        ) -> Vec<std::result::Result<Output, Self::Error>> {
            unimplemented!()
        }

        async fn do_exec_plan(
            &self,
            _plan: LogicalPlan,
            _query_ctx: QueryContextRef,
        ) -> std::result::Result<Output, Self::Error> {
            unimplemented!()
        }

        async fn do_describe(
            &self,
            _stmt: sql::statements::statement::Statement,
            _query_ctx: QueryContextRef,
        ) -> Result<Option<DescribeResult>> {
            unimplemented!()
        }

        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
            Ok(true)
        }
    }

    /// Timeout layer with a 10ms default timeout.
    fn timeout() -> DynamicTimeoutLayer {
        DynamicTimeoutLayer::new(Duration::from_millis(10))
    }

    /// A handler that never completes, used to trigger request timeouts.
    async fn forever() {
        pending().await
    }

    /// Builds the test app with default [`HttpOptions`].
    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
        make_test_app_custom(tx, HttpOptions::default())
    }

    /// Builds the full app for `options` plus a `/test/timeout` route whose handler never
    /// completes and is wrapped in [`timeout`].
    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
        let instance = Arc::new(DummyInstance { _tx: tx });
        let sql_instance = ServerSqlQueryHandlerAdapter::arc(instance.clone());
        let server = HttpServerBuilder::new(options)
            .with_sql_handler(sql_instance)
            .build();
        server.build(server.make_app()).unwrap().route(
            "/test/timeout",
            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
        )
    }

    /// Asserts that `output` holds the 4-row fixture built in `test_recordbatches_conversion`:
    /// a `numbers` (UInt32) column and a nullable `strings` column, first row `[1, null]`.
    fn assert_fixture_records(output: &GreptimeQueryOutput) {
        let GreptimeQueryOutput::Records(r) = output else {
            panic!("invalid output type");
        };
        assert_eq!(r.num_rows(), 4);
        assert_eq!(r.num_cols(), 2);
        assert_eq!(r.schema.column_schemas[0].name, "numbers");
        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
        assert_eq!(r.rows[0][1], serde_json::Value::Null);
    }

    #[tokio::test]
    pub async fn test_cors() {
        // CORS is enabled by default and allows any origin.
        let (tx, _rx) = mpsc::channel(100);
        let app = make_test_app(tx);
        let client = TestClient::new(app).await;

        let res = client.get("/health").send().await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .expect("expect cors header origin"),
            "*"
        );

        let res = client.get("/v1/health").send().await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .expect("expect cors header origin"),
            "*"
        );

        // A preflight request is answered with the permissive CORS headers.
        let res = client
            .options("/health")
            .header("Access-Control-Request-Headers", "x-greptime-auth")
            .header("Access-Control-Request-Method", "DELETE")
            .header("Origin", "https://example.com")
            .send()
            .await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .expect("expect cors header origin"),
            "*"
        );
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
                .expect("expect cors header headers"),
            "*"
        );
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
                .expect("expect cors header methods"),
            "GET,POST,PUT,DELETE,HEAD"
        );
    }

    #[tokio::test]
    pub async fn test_cors_custom_origins() {
        // With an explicit allow-list, only listed origins are echoed back.
        let (tx, _rx) = mpsc::channel(100);
        let origin = "https://example.com";

        let options = HttpOptions {
            cors_allowed_origins: vec![origin.to_string()],
            ..Default::default()
        };

        let app = make_test_app_custom(tx, options);
        let client = TestClient::new(app).await;

        let res = client.get("/health").header("Origin", origin).send().await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()
                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .expect("expect cors header origin"),
            origin
        );

        // An origin outside the allow-list still gets a response, but no CORS header.
        let res = client
            .get("/health")
            .header("Origin", "https://notallowed.com")
            .send()
            .await;

        assert_eq!(res.status(), StatusCode::OK);
        assert!(!res
            .headers()
            .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN));
    }

    #[tokio::test]
    pub async fn test_cors_disabled() {
        // With CORS turned off, no `Access-Control-Allow-Origin` header is emitted.
        let (tx, _rx) = mpsc::channel(100);

        let options = HttpOptions {
            enable_cors: false,
            ..Default::default()
        };

        let app = make_test_app_custom(tx, options);
        let client = TestClient::new(app).await;

        let res = client.get("/health").send().await;

        assert_eq!(res.status(), StatusCode::OK);
        assert!(!res
            .headers()
            .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN));
    }

    #[test]
    fn test_http_options_default() {
        let default = HttpOptions::default();
        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
        assert_eq!(Duration::from_secs(0), default.timeout);
    }

    #[tokio::test]
    async fn test_http_server_request_timeout() {
        common_telemetry::init_default_ut_logging();

        let (tx, _rx) = mpsc::channel(100);
        let app = make_test_app(tx);
        let client = TestClient::new(app).await;

        // The layer's default 10ms timeout fires.
        let res = client.get("/test/timeout").send().await;
        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);

        // A longer timeout from the header is honoured: the request still times out, but not
        // before the default 10ms would have elapsed.
        let now = Instant::now();
        let res = client
            .get("/test/timeout")
            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
            .send()
            .await;
        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
        let elapsed = now.elapsed();
        assert!(elapsed > Duration::from_millis(15));

        // A zero timeout in the header disables the request timeout, so the never-ending
        // handler must still be pending when the outer 15ms timer fires.
        tokio::time::timeout(
            Duration::from_millis(15),
            client
                .get("/test/timeout")
                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
                .send(),
        )
        .await
        .unwrap_err();

        // Same as above, using the humantime rendering of a zero duration.
        tokio::time::timeout(
            Duration::from_millis(15),
            client
                .get("/test/timeout")
                .header(
                    GREPTIME_DB_HEADER_TIMEOUT,
                    humantime::format_duration(Duration::default()).to_string(),
                )
                .send(),
        )
        .await
        .unwrap_err();
    }

    #[tokio::test]
    async fn test_schema_for_empty_response() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));

        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];

        let json_resp = GreptimedbV1Response::from_output(outputs).await;
        let HttpResponse::GreptimedbV1(json_resp) = json_resp else {
            panic!("invalid format")
        };
        let GreptimeQueryOutput::Records(r) = &json_resp.output[0] else {
            panic!("invalid output type");
        };
        // The schema is reported even though there are no rows.
        assert_eq!(r.num_rows(), 0);
        assert_eq!(r.num_cols(), 2);
        assert_eq!(r.schema.column_schemas[0].name, "numbers");
        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
    }

    #[tokio::test]
    async fn test_recordbatches_conversion() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();

        for format in [
            ResponseFormat::GreptimedbV1,
            ResponseFormat::InfluxdbV1,
            ResponseFormat::Csv(true, true),
            ResponseFormat::Table,
            ResponseFormat::Arrow,
            ResponseFormat::Json,
        ] {
            let recordbatches =
                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
            let response = match format {
                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
                ResponseFormat::Csv(with_names, with_types) => {
                    CsvResponse::from_output(outputs, with_names, with_types).await
                }
                ResponseFormat::Table => TableResponse::from_output(outputs).await,
                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
            };

            match response {
                HttpResponse::GreptimedbV1(resp) => assert_fixture_records(&resp.output[0]),
                HttpResponse::Csv(resp) => assert_fixture_records(&resp.output()[0]),
                HttpResponse::Table(resp) => assert_fixture_records(&resp.output()[0]),
                HttpResponse::Json(resp) => assert_fixture_records(&resp.output()[0]),

                HttpResponse::InfluxdbV1(resp) => {
                    let json_output = &resp.results()[0];
                    assert_eq!(json_output.num_rows(), 4);
                    assert_eq!(json_output.num_cols(), 2);
                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
                    assert_eq!(
                        json_output.series[0].values[0][0],
                        serde_json::Value::from(1)
                    );
                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
                }

                HttpResponse::Arrow(resp) => {
                    // The Arrow response is an IPC file; read it back and check the schema.
                    let output = resp.data;
                    let mut reader =
                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
                    let schema = reader.schema();
                    assert_eq!(schema.fields[0].name(), "numbers");
                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
                    assert_eq!(schema.fields[1].name(), "strings");
                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);

                    let rb = reader.next().unwrap().expect("read record batch failed");
                    assert_eq!(rb.num_columns(), 2);
                    assert_eq!(rb.num_rows(), 4);
                }

                HttpResponse::Error(err) => unreachable!("{err:?}"),
            }
        }
    }

    #[test]
    fn test_response_format_misc() {
        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
        assert_eq!(
            ResponseFormat::parse("csv"),
            Some(ResponseFormat::Csv(false, false))
        );
        assert_eq!(
            ResponseFormat::parse("csvwithnames"),
            Some(ResponseFormat::Csv(true, false))
        );
        assert_eq!(
            ResponseFormat::parse("csvwithnamesandtypes"),
            Some(ResponseFormat::Csv(true, true))
        );
        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
        assert_eq!(
            ResponseFormat::parse("greptimedb_v1"),
            Some(ResponseFormat::GreptimedbV1)
        );
        assert_eq!(
            ResponseFormat::parse("influxdb_v1"),
            Some(ResponseFormat::InfluxdbV1)
        );
        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));

        // invalid formats
        assert_eq!(ResponseFormat::parse("invalid"), None);
        assert_eq!(ResponseFormat::parse(""), None);
        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive

        // as str
        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
        assert_eq!(ResponseFormat::Table.as_str(), "table");
        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
        assert_eq!(ResponseFormat::Json.as_str(), "json");
    }
}