Tidy up engine crate (#1296)

<!-- The PR description should answer 2 important questions: -->

### What

Needed to reuse some stuff when putting JSONAPI into multitenant, and
everything is a big tangled mess, so thought it was finally time to
clean up shop. Moves a bunch of stuff from `/bin` into the library in
`src` and splits it into files.

Functional no-op.

V3_GIT_ORIGIN_REV_ID: 87406e3fb63a1f90347782abeda3d4db14386270
This commit is contained in:
Daniel Harvey 2024-10-29 16:47:30 +00:00 committed by hasura-bot
parent d5f70fd56a
commit 5790c088b1
11 changed files with 699 additions and 648 deletions

View File

@ -1,56 +1,19 @@
use futures_util::FutureExt;
use hasura_authn::{authenticate, resolve_auth_config, AuthConfig};
use json_api::create_json_api_router;
use std::fmt::Display;
use std::hash;
use std::hash::{Hash, Hasher};
use clap::Parser;
use serde::Serialize;
use std::net;
use std::path::PathBuf;
use std::sync::Arc;
use axum::{
extract::{DefaultBodyLimit, State},
http::{HeaderMap, Request},
middleware::Next,
response::{Html, IntoResponse},
routing::{get, post},
Extension, Json, Router,
};
use axum_core::body::Body;
use base64::engine::Engine;
use clap::Parser;
use http_body_util::BodyExt;
use metadata_resolve::LifecyclePluginConfigs;
use pre_parse_plugin::execute::pre_parse_plugins_handler;
use pre_response_plugin::execute::pre_response_plugins_handler;
use reqwest::header::CONTENT_TYPE;
use serde::Serialize;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use engine::{
internal_flags::{resolve_unstable_features, UnstableFeature},
VERSION,
EngineRouter, StartupError, VERSION,
};
use execute::HttpContext;
use graphql_schema::GDS;
use hasura_authn_core::Session;
use lang_graphql as gql;
use tracing_util::{
add_event_on_active_span, set_attribute_on_active_span, set_status_on_current_span,
ErrorVisibility, SpanVisibility, TraceableError, TraceableHttpResponse,
};
mod cors;
mod json_api;
use tracing_util::{add_event_on_active_span, set_attribute_on_active_span, SpanVisibility};
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
const DEFAULT_PORT: u16 = 3000;
const MB: usize = 1_048_576;
#[allow(clippy::struct_excessive_bools)] // booleans are pretty useful here
#[derive(Parser, Serialize)]
#[command(version = VERSION)]
@ -109,19 +72,6 @@ struct ServerOptions {
export_traces_stdout: bool,
}
#[derive(Clone)] // Cheap to clone as heavy fields are wrapped in `Arc`
pub struct EngineState {
expose_internal_errors: execute::ExposeInternalErrors,
http_context: HttpContext,
graphql_state: Arc<gql::schema::Schema<GDS>>,
resolved_metadata: Arc<metadata_resolve::Metadata>,
jsonapi_catalog: Arc<jsonapi::Catalog>,
auth_config: Arc<AuthConfig>,
sql_context: Arc<sql::catalog::Catalog>,
plugin_configs: Arc<LifecyclePluginConfigs>,
graphql_websocket_server: Arc<graphql_ws::WebSocketServer<graphql_ws::NoOpWebSocketMetrics>>,
}
#[tokio::main]
#[allow(clippy::print_stdout)]
async fn main() {
@ -156,186 +106,6 @@ async fn main() {
tracing_util::shutdown_tracer();
}
#[derive(thiserror::Error, Debug)]
#[allow(clippy::enum_variant_names)]
enum StartupError {
#[error("could not read the auth config - {0}")]
ReadAuth(anyhow::Error),
#[error("failed to build engine state - {0}")]
ReadSchema(anyhow::Error),
}
impl TraceableError for StartupError {
fn visibility(&self) -> tracing_util::ErrorVisibility {
ErrorVisibility::User
}
}
/// The main router for the engine.
struct EngineRouter {
/// The base router for the engine.
/// Contains /, /graphql, /v1/explain and /health routes.
base_router: Router,
/// The metadata routes for the introspection metadata file.
/// Contains /metadata and /metadata-hash routes.
metadata_routes: Option<Router>,
/// Routes for the SQL interface
sql_routes: Option<Router>,
/// Routes for the JSON:API interface
jsonapi_routes: Option<Router>,
/// The CORS layer for the engine.
cors_layer: Option<CorsLayer>,
}
impl EngineRouter {
fn new(state: EngineState) -> Self {
let graphql_ws_route = Router::new()
.route("/graphql", get(handle_websocket_request))
.layer(axum::middleware::from_fn(
graphql_request_tracing_middleware,
))
// *PLEASE DO NOT ADD ANY MIDDLEWARE
// BEFORE THE `graphql_request_tracing_middleware`*
// Refer to it for more details.
.layer(TraceLayer::new_for_http())
.with_state(state.clone());
let graphql_route = Router::new()
.route("/graphql", post(handle_request))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
plugins_middleware,
))
.layer(axum::middleware::from_fn(
hasura_authn_core::resolve_session,
))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
authentication_middleware,
))
.layer(axum::middleware::from_fn(
graphql_request_tracing_middleware,
))
// *PLEASE DO NOT ADD ANY MIDDLEWARE
// BEFORE THE `graphql_request_tracing_middleware`*
// Refer to it for more details.
.layer(TraceLayer::new_for_http())
.with_state(state.clone());
let explain_route = Router::new()
.route("/v1/explain", post(handle_explain_request))
.layer(axum::middleware::from_fn(
hasura_authn_core::resolve_session,
))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
authentication_middleware,
))
.layer(axum::middleware::from_fn(
explain_request_tracing_middleware,
))
// *PLEASE DO NOT ADD ANY MIDDLEWARE
// BEFORE THE `explain_request_tracing_middleware`*
// Refer to it for more details.
.layer(TraceLayer::new_for_http())
.with_state(state);
let health_route = Router::new().route("/health", get(handle_health));
let base_routes = Router::new()
// serve graphiql at root
.route("/", get(graphiql))
// The '/graphql' route
.merge(graphql_route)
// The '/graphql' route for websocket
.merge(graphql_ws_route)
// The '/v1/explain' route
.merge(explain_route)
// The '/health' route
.merge(health_route)
// Set request payload limit to 10 MB
.layer(DefaultBodyLimit::max(10 * MB));
Self {
base_router: base_routes,
metadata_routes: None,
sql_routes: None,
jsonapi_routes: None,
cors_layer: None,
}
}
/// Serve the introspection metadata file and its hash at `/metadata` and `/metadata-hash` respectively.
/// This is a temporary workaround to enable the console to interact with an engine process running locally.
async fn add_metadata_routes(
&mut self,
introspection_metadata_path: &PathBuf,
) -> Result<(), StartupError> {
let file_contents = tokio::fs::read_to_string(introspection_metadata_path)
.await
.map_err(|err| StartupError::ReadSchema(err.into()))?;
let mut hasher = hash::DefaultHasher::new();
file_contents.hash(&mut hasher);
let hash = hasher.finish();
let base64_hash = base64::engine::general_purpose::STANDARD.encode(hash.to_ne_bytes());
let metadata_routes = Router::new()
.route("/metadata", get(|| async { file_contents }))
.route("/metadata-hash", get(|| async { base64_hash }));
self.metadata_routes = Some(metadata_routes);
Ok(())
}
fn add_sql_route(&mut self, state: EngineState) {
let sql_routes = Router::new()
.route("/v1/sql", post(handle_sql_request))
.layer(axum::middleware::from_fn(
hasura_authn_core::resolve_session,
))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
authentication_middleware,
))
.layer(axum::middleware::from_fn(sql_request_tracing_middleware))
// *PLEASE DO NOT ADD ANY MIDDLEWARE
// BEFORE THE `explain_request_tracing_middleware`*
// Refer to it for more details.
.layer(TraceLayer::new_for_http())
.with_state(state);
self.sql_routes = Some(sql_routes);
}
fn add_jsonapi_route(&mut self, state: EngineState) {
let jsonapi_routes = create_json_api_router(state);
self.jsonapi_routes = Some(jsonapi_routes);
}
fn add_cors_layer(&mut self, allow_origin: &[String]) {
self.cors_layer = Some(cors::build_cors_layer(allow_origin));
}
fn into_make_service(self) -> axum::routing::IntoMakeService<Router> {
let mut app = self.base_router;
// Merge the metadata routes if they exist.
if let Some(sql_routes) = self.sql_routes {
app = app.merge(sql_routes);
}
if let Some(jsonapi_routes) = self.jsonapi_routes {
app = app.merge(jsonapi_routes);
}
// Merge the metadata routes if they exist.
if let Some(metadata_routes) = self.metadata_routes {
app = app.merge(metadata_routes);
}
// Add the CORS layer if it exists.
if let Some(cors_layer) = self.cors_layer {
// It is important that this layer is added last, since it only affects
// the layers that precede it.
app = app.layer(cors_layer);
}
app.into_make_service()
}
}
#[allow(clippy::print_stdout)]
async fn start_engine(server: &ServerOptions) -> Result<(), StartupError> {
let metadata_resolve_configuration = metadata_resolve::configuration::Configuration {
@ -348,7 +118,7 @@ async fn start_engine(server: &ServerOptions) -> Result<(), StartupError> {
execute::ExposeInternalErrors::Censor
};
let state = build_state(
let state = engine::build_state(
expose_internal_errors,
&server.authn_config_path,
&server.metadata_path,
@ -407,411 +177,3 @@ async fn start_engine(server: &ServerOptions) -> Result<(), StartupError> {
Ok(())
}
/// Health check endpoint
async fn handle_health() -> reqwest::StatusCode {
reqwest::StatusCode::OK
}
/// Middleware to start tracing of the `/graphql` request.
/// This middleware must be active for the entire duration
/// of the request i.e. this middleware should be the
/// entry point and the exit point of the GraphQL request.
async fn graphql_request_tracing_middleware(
request: Request<Body>,
next: Next,
) -> axum::response::Response {
use tracing_util::*;
let tracer = global_tracer();
let path = "/graphql";
tracer
.in_span_async_with_parent_context(
path,
path,
SpanVisibility::User,
&request.headers().clone(),
|| {
set_attribute_on_active_span(AttributeVisibility::Internal, "version", VERSION);
Box::pin(async move {
let mut response = next.run(request).await;
get_text_map_propagator(|propagator| {
propagator.inject_context(
&Context::current(),
&mut HeaderInjector(response.headers_mut()),
);
});
TraceableHttpResponse::new(response, path)
})
},
)
.await
.response
}
/// Middleware to start tracing of the `/v1/explain` request.
/// This middleware must be active for the entire duration
/// of the request i.e. this middleware should be the
/// entry point and the exit point of the GraphQL request.
async fn explain_request_tracing_middleware(
request: Request<Body>,
next: Next,
) -> axum::response::Response {
let tracer = tracing_util::global_tracer();
let path = "/v1/explain";
tracer
.in_span_async_with_parent_context(
path,
path,
SpanVisibility::User,
&request.headers().clone(),
|| {
Box::pin(async move {
let response = next.run(request).await;
TraceableHttpResponse::new(response, path)
})
},
)
.await
.response
}
/// Middleware to start tracing of the `/v1/sql` request.
/// This middleware must be active for the entire duration
/// of the request i.e. this middleware should be the
/// entry point and the exit point of the SQL request.
async fn sql_request_tracing_middleware(
request: Request<Body>,
next: Next,
) -> axum::response::Response {
let tracer = tracing_util::global_tracer();
let path = "/v1/sql";
tracer
.in_span_async_with_parent_context(
path,
path,
SpanVisibility::User,
&request.headers().clone(),
|| {
Box::pin(async move {
let response = next.run(request).await;
TraceableHttpResponse::new(response, path)
})
},
)
.await
.response
}
/// This middleware authenticates the incoming GraphQL request according to the
/// authentication configuration present in the `auth_config` of `EngineState`. The
/// result of the authentication is `hasura-authn-core::Identity`, which is then
/// made available to the GraphQL request handler.
pub async fn authentication_middleware<'a>(
State(engine_state): State<EngineState>,
headers_map: HeaderMap,
mut request: Request<Body>,
next: Next,
) -> axum::response::Result<axum::response::Response> {
let tracer = tracing_util::global_tracer();
let resolved_identity = tracer
.in_span_async(
"authentication_middleware",
"Authentication middleware",
SpanVisibility::Internal,
|| {
Box::pin(authenticate(
&headers_map,
&engine_state.http_context.client,
&engine_state.auth_config,
))
},
)
.await?;
request.extensions_mut().insert(resolved_identity);
Ok(next.run(request).await)
}
async fn graphiql() -> Html<&'static str> {
Html(include_str!("index.html"))
}
async fn handle_request(
headers: axum::http::header::HeaderMap,
State(state): State<EngineState>,
Extension(session): Extension<Session>,
Json(request): Json<gql::http::RawRequest>,
) -> gql::http::Response {
let tracer = tracing_util::global_tracer();
let response = tracer
.in_span_async(
"handle_request",
"Handle request",
SpanVisibility::User,
|| {
{
Box::pin(
graphql_frontend::execute_query(
state.expose_internal_errors,
&state.http_context,
&state.graphql_state,
&session,
&headers,
request,
None,
)
.map(|(_operation_type, graphql_response)| graphql_response),
)
}
},
)
.await;
// Set the span as error if the response contains an error
// NOTE: Ideally, we should mark the root span as error in `graphql_request_tracing_middleware` function,
// the tracing middleware, where the span is initialized. It is possible by completing the implementation
// of `Traceable` trait for `AxumResponse` struct. The said struct just wraps the `axum::response::Response`.
// The only way to determine the error is to inspect the status code from the `Response` struct.
// In `/graphql` API, all responses are sent with `200` OK including errors, which leaves no way to deduce errors in the tracing middleware.
set_status_on_current_span(&response);
response.inner()
}
async fn handle_explain_request(
headers: axum::http::header::HeaderMap,
State(state): State<EngineState>,
Extension(session): Extension<Session>,
Json(request): Json<gql::http::RawRequest>,
) -> graphql_frontend::ExplainResponse {
let tracer = tracing_util::global_tracer();
let response = tracer
.in_span_async(
"handle_explain_request",
"Handle explain request",
SpanVisibility::User,
|| {
Box::pin(
graphql_frontend::execute_explain(
state.expose_internal_errors,
&state.http_context,
&state.graphql_state,
&session,
&headers,
request,
)
.map(|(_operation_type, graphql_response)| graphql_response),
)
},
)
.await;
// Set the span as error if the response contains an error
set_status_on_current_span(&response);
response
}
async fn plugins_middleware(
State(engine_state): State<EngineState>,
Extension(session): Extension<Session>,
headers_map: HeaderMap,
request: Request<axum::body::Body>,
next: Next,
) -> axum::response::Result<axum::response::Response<Body>> {
let (parts, body) = request.into_parts();
let bytes = body
.collect()
.await
.map_err(|err| {
(reqwest::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response()
})?
.to_bytes();
let raw_request = bytes.clone();
// Check if the pre_parse_plugins_config is empty
let response =
match nonempty::NonEmpty::from_slice(&engine_state.plugin_configs.pre_parse_plugins) {
None => {
// If empty, do nothing and pass the request to the next middleware
let recreated_request = Request::from_parts(parts, axum::body::Body::from(bytes));
Ok::<_, axum::response::ErrorResponse>(next.run(recreated_request).await)
}
Some(pre_parse_plugins) => {
let response = pre_parse_plugins_handler(
&pre_parse_plugins,
&engine_state.http_context.client,
session.clone(),
&bytes,
headers_map.clone(),
)
.await?;
if let Some(response) = response {
Ok(response)
} else {
let recreated_request =
Request::from_parts(parts, axum::body::Body::from(bytes));
Ok(next.run(recreated_request).await)
}
}
}?;
let (parts, body) = response.into_parts();
let response_bytes = body
.collect()
.await
.map_err(|err| {
(reqwest::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response()
})?
.to_bytes();
if let Some(pre_response_plugins) =
nonempty::NonEmpty::from_slice(&engine_state.plugin_configs.pre_response_plugins)
{
pre_response_plugins_handler(
&pre_response_plugins,
&engine_state.http_context.client,
session,
&raw_request,
&response_bytes,
headers_map,
)?;
}
let recreated_response =
axum::response::Response::from_parts(parts, axum::body::Body::from(response_bytes));
Ok(recreated_response)
}
/// Handle a SQL request and execute it.
pub async fn handle_sql_request(
headers: axum::http::header::HeaderMap,
State(state): State<EngineState>,
Extension(session): Extension<Session>,
Json(request): Json<sql::execute::SqlRequest>,
) -> axum::response::Response {
let tracer = tracing_util::global_tracer();
let response = tracer
.in_span_async(
"handle_sql_request",
"Handle SQL Request",
SpanVisibility::User,
|| {
Box::pin(async {
sql::execute::execute_sql(
Arc::new(headers),
state.sql_context.clone(),
Arc::new(session),
Arc::new(state.http_context.clone()),
&request,
)
.await
})
},
)
.await;
// Set the span as error if the response contains an error
set_status_on_current_span(&response);
match response {
Ok(r) => {
let mut response = (axum::http::StatusCode::OK, r).into_response();
response.headers_mut().insert(
CONTENT_TYPE,
axum::http::HeaderValue::from_static("application/json"),
);
response
}
Err(e) => (
axum::http::StatusCode::BAD_REQUEST,
Json(e.to_error_response()),
)
.into_response(),
}
}
#[allow(clippy::print_stdout)]
/// Print any build warnings to stdout
fn print_warnings<T: Display>(warnings: Vec<T>) {
for warning in warnings {
println!("Warning: {warning}");
}
}
/// Build the engine state - include auth, metadata, and sql context.
fn build_state(
expose_internal_errors: execute::ExposeInternalErrors,
authn_config_path: &PathBuf,
metadata_path: &PathBuf,
enable_sql_interface: bool,
metadata_resolve_configuration: &metadata_resolve::configuration::Configuration,
) -> Result<EngineState, anyhow::Error> {
// Auth Config
let raw_auth_config = std::fs::read_to_string(authn_config_path)?;
let (auth_config, auth_warnings) =
resolve_auth_config(&raw_auth_config).map_err(StartupError::ReadAuth)?;
// Metadata
let raw_metadata = std::fs::read_to_string(metadata_path)?;
let metadata = open_dds::Metadata::from_json_str(&raw_metadata)?;
let (resolved_metadata, warnings) =
metadata_resolve::resolve(metadata, metadata_resolve_configuration)?;
let resolved_metadata = Arc::new(resolved_metadata);
print_warnings(auth_warnings);
print_warnings(warnings);
let http_context = HttpContext {
client: reqwest::Client::new(),
ndc_response_size_limit: None,
};
let plugin_configs = resolved_metadata.plugin_configs.clone();
let sql_context = if enable_sql_interface {
sql::catalog::Catalog::from_metadata(resolved_metadata.clone())
} else {
sql::catalog::Catalog::empty_from_metadata(resolved_metadata.clone())
};
let schema = graphql_schema::GDS {
metadata: resolved_metadata.clone(),
}
.build_schema()?;
let (jsonapi_catalog, _json_api_warnings) = jsonapi::Catalog::new(&resolved_metadata);
let state = EngineState {
expose_internal_errors,
http_context,
graphql_state: Arc::new(schema),
jsonapi_catalog: Arc::new(jsonapi_catalog),
resolved_metadata,
auth_config: Arc::new(auth_config),
sql_context: sql_context.into(),
plugin_configs: Arc::new(plugin_configs),
graphql_websocket_server: Arc::new(graphql_ws::WebSocketServer::new()),
};
Ok(state)
}
async fn handle_websocket_request(
headers: axum::http::header::HeaderMap,
State(engine_state): State<EngineState>,
ws: axum::extract::ws::WebSocketUpgrade,
) -> impl IntoResponse {
// Create the context for the websocket server
let context = graphql_ws::Context {
connection_expiry: graphql_ws::ConnectionExpiry::Never,
http_context: engine_state.http_context,
project_id: None, // project_id is not needed for OSS v3-engine.
expose_internal_errors: engine_state.expose_internal_errors,
schema: engine_state.graphql_state,
auth_config: engine_state.auth_config,
plugin_configs: engine_state.plugin_configs,
metrics: graphql_ws::NoOpWebSocketMetrics, // No metrics implementation
};
engine_state
.graphql_websocket_server
.upgrade_and_handle_websocket(ws, &headers, context)
}

View File

@ -1,5 +1,19 @@
pub mod build;
mod cors;
pub mod internal_flags;
mod middleware;
mod routes;
mod state;
mod types;
pub use cors::build_cors_layer;
pub use middleware::{
authentication_middleware, explain_request_tracing_middleware,
graphql_request_tracing_middleware, plugins_middleware, sql_request_tracing_middleware,
};
pub use routes::EngineRouter;
pub use state::build_state;
pub use types::{EngineState, StartupError};
// This is set by the build.rs script.
/// The version of the v3-engine release.

View File

@ -0,0 +1,209 @@
use crate::EngineState;
use crate::VERSION;
use axum::{
extract::State,
http::{HeaderMap, Request},
middleware::Next,
response::IntoResponse,
Extension,
};
use axum_core::body::Body;
use hasura_authn::authenticate;
use http_body_util::BodyExt;
use pre_parse_plugin::execute::pre_parse_plugins_handler;
use pre_response_plugin::execute::pre_response_plugins_handler;
use hasura_authn_core::Session;
use tracing_util::{SpanVisibility, TraceableHttpResponse};
/// Middleware to start tracing of the `/graphql` request.
/// This middleware must be active for the entire duration
/// of the request i.e. this middleware should be the
/// entry point and the exit point of the GraphQL request.
pub async fn graphql_request_tracing_middleware(
request: Request<Body>,
next: Next,
) -> axum::response::Response {
use tracing_util::*;
let tracer = global_tracer();
let path = "/graphql";
tracer
.in_span_async_with_parent_context(
path,
path,
SpanVisibility::User,
&request.headers().clone(),
|| {
set_attribute_on_active_span(AttributeVisibility::Internal, "version", VERSION);
Box::pin(async move {
let mut response = next.run(request).await;
get_text_map_propagator(|propagator| {
propagator.inject_context(
&Context::current(),
&mut HeaderInjector(response.headers_mut()),
);
});
TraceableHttpResponse::new(response, path)
})
},
)
.await
.response
}
/// Middleware to start tracing of the `/v1/explain` request.
/// This middleware must be active for the entire duration
/// of the request i.e. this middleware should be the
/// entry point and the exit point of the GraphQL request.
pub async fn explain_request_tracing_middleware(
request: Request<Body>,
next: Next,
) -> axum::response::Response {
let tracer = tracing_util::global_tracer();
let path = "/v1/explain";
tracer
.in_span_async_with_parent_context(
path,
path,
SpanVisibility::User,
&request.headers().clone(),
|| {
Box::pin(async move {
let response = next.run(request).await;
TraceableHttpResponse::new(response, path)
})
},
)
.await
.response
}
/// Middleware to start tracing of the `/v1/sql` request.
/// This middleware must be active for the entire duration
/// of the request i.e. this middleware should be the
/// entry point and the exit point of the SQL request.
pub async fn sql_request_tracing_middleware(
request: Request<Body>,
next: Next,
) -> axum::response::Response {
let tracer = tracing_util::global_tracer();
let path = "/v1/sql";
tracer
.in_span_async_with_parent_context(
path,
path,
SpanVisibility::User,
&request.headers().clone(),
|| {
Box::pin(async move {
let response = next.run(request).await;
TraceableHttpResponse::new(response, path)
})
},
)
.await
.response
}
/// This middleware authenticates the incoming GraphQL request according to the
/// authentication configuration present in the `auth_config` of `EngineState`. The
/// result of the authentication is `hasura-authn-core::Identity`, which is then
/// made available to the GraphQL request handler.
pub async fn authentication_middleware<'a>(
State(engine_state): State<EngineState>,
headers_map: HeaderMap,
mut request: Request<Body>,
next: Next,
) -> axum::response::Result<axum::response::Response> {
let tracer = tracing_util::global_tracer();
let resolved_identity = tracer
.in_span_async(
"authentication_middleware",
"Authentication middleware",
SpanVisibility::Internal,
|| {
Box::pin(authenticate(
&headers_map,
&engine_state.http_context.client,
&engine_state.auth_config,
))
},
)
.await?;
request.extensions_mut().insert(resolved_identity);
Ok(next.run(request).await)
}
pub async fn plugins_middleware(
State(engine_state): State<EngineState>,
Extension(session): Extension<Session>,
headers_map: HeaderMap,
request: Request<axum::body::Body>,
next: Next,
) -> axum::response::Result<axum::response::Response<Body>> {
let (parts, body) = request.into_parts();
let bytes = body
.collect()
.await
.map_err(|err| {
(reqwest::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response()
})?
.to_bytes();
let raw_request = bytes.clone();
// Check if the pre_parse_plugins_config is empty
let response =
match nonempty::NonEmpty::from_slice(&engine_state.plugin_configs.pre_parse_plugins) {
None => {
// If empty, do nothing and pass the request to the next middleware
let recreated_request = Request::from_parts(parts, axum::body::Body::from(bytes));
Ok::<_, axum::response::ErrorResponse>(next.run(recreated_request).await)
}
Some(pre_parse_plugins) => {
let response = pre_parse_plugins_handler(
&pre_parse_plugins,
&engine_state.http_context.client,
session.clone(),
&bytes,
headers_map.clone(),
)
.await?;
if let Some(response) = response {
Ok(response)
} else {
let recreated_request =
Request::from_parts(parts, axum::body::Body::from(bytes));
Ok(next.run(recreated_request).await)
}
}
}?;
let (parts, body) = response.into_parts();
let response_bytes = body
.collect()
.await
.map_err(|err| {
(reqwest::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response()
})?
.to_bytes();
if let Some(pre_response_plugins) =
nonempty::NonEmpty::from_slice(&engine_state.plugin_configs.pre_response_plugins)
{
pre_response_plugins_handler(
&pre_response_plugins,
&engine_state.http_context.client,
session,
&raw_request,
&response_bytes,
headers_map,
)?;
}
let recreated_response =
axum::response::Response::from_parts(parts, axum::body::Body::from(response_bytes));
Ok(recreated_response)
}

View File

@ -0,0 +1,201 @@
mod sql;
pub use sql::handle_sql_request;
mod graphql;
pub use graphql::{handle_explain_request, handle_request, handle_websocket_request};
mod jsonapi;
pub use jsonapi::create_json_api_router;
use axum::{
extract::DefaultBodyLimit,
response::Html,
routing::{get, post},
Router,
};
use base64::engine::Engine;
use std::hash;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use crate::{
authentication_middleware, build_cors_layer, explain_request_tracing_middleware,
graphql_request_tracing_middleware, plugins_middleware, sql_request_tracing_middleware,
EngineState, StartupError,
};
const MB: usize = 1_048_576;
/// The main router for the engine.
pub struct EngineRouter {
/// The base router for the engine.
/// Contains /, /graphql, /v1/explain and /health routes.
base_router: Router,
/// The metadata routes for the introspection metadata file.
/// Contains /metadata and /metadata-hash routes.
metadata_routes: Option<Router>,
/// Routes for the SQL interface
sql_routes: Option<Router>,
/// Routes for the JSON:API interface
jsonapi_routes: Option<Router>,
/// The CORS layer for the engine.
cors_layer: Option<CorsLayer>,
}
impl EngineRouter {
pub fn new(state: EngineState) -> Self {
let graphql_ws_route = Router::new()
.route("/graphql", get(handle_websocket_request))
.layer(axum::middleware::from_fn(
graphql_request_tracing_middleware,
))
// *PLEASE DO NOT ADD ANY MIDDLEWARE
// BEFORE THE `graphql_request_tracing_middleware`*
// Refer to it for more details.
.layer(TraceLayer::new_for_http())
.with_state(state.clone());
let graphql_route = Router::new()
.route("/graphql", post(handle_request))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
plugins_middleware,
))
.layer(axum::middleware::from_fn(
hasura_authn_core::resolve_session,
))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
authentication_middleware,
))
.layer(axum::middleware::from_fn(
graphql_request_tracing_middleware,
))
// *PLEASE DO NOT ADD ANY MIDDLEWARE
// BEFORE THE `graphql_request_tracing_middleware`*
// Refer to it for more details.
.layer(TraceLayer::new_for_http())
.with_state(state.clone());
let explain_route = Router::new()
.route("/v1/explain", post(handle_explain_request))
.layer(axum::middleware::from_fn(
hasura_authn_core::resolve_session,
))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
authentication_middleware,
))
.layer(axum::middleware::from_fn(
explain_request_tracing_middleware,
))
// *PLEASE DO NOT ADD ANY MIDDLEWARE
// BEFORE THE `explain_request_tracing_middleware`*
// Refer to it for more details.
.layer(TraceLayer::new_for_http())
.with_state(state);
let health_route = Router::new().route("/health", get(handle_health));
let base_routes = Router::new()
// serve graphiql at root
.route("/", get(graphiql))
// The '/graphql' route
.merge(graphql_route)
// The '/graphql' route for websocket
.merge(graphql_ws_route)
// The '/v1/explain' route
.merge(explain_route)
// The '/health' route
.merge(health_route)
// Set request payload limit to 10 MB
.layer(DefaultBodyLimit::max(10 * MB));
Self {
base_router: base_routes,
metadata_routes: None,
sql_routes: None,
jsonapi_routes: None,
cors_layer: None,
}
}
/// Serve the introspection metadata file and its hash at `/metadata` and `/metadata-hash` respectively.
/// This is a temporary workaround to enable the console to interact with an engine process running locally.
pub async fn add_metadata_routes(
&mut self,
introspection_metadata_path: &PathBuf,
) -> Result<(), StartupError> {
let file_contents = tokio::fs::read_to_string(introspection_metadata_path)
.await
.map_err(|err| StartupError::ReadSchema(err.into()))?;
let mut hasher = hash::DefaultHasher::new();
file_contents.hash(&mut hasher);
let hash = hasher.finish();
let base64_hash = base64::engine::general_purpose::STANDARD.encode(hash.to_ne_bytes());
let metadata_routes = Router::new()
.route("/metadata", get(|| async { file_contents }))
.route("/metadata-hash", get(|| async { base64_hash }));
self.metadata_routes = Some(metadata_routes);
Ok(())
}
pub fn add_sql_route(&mut self, state: EngineState) {
let sql_routes = Router::new()
.route("/v1/sql", post(handle_sql_request))
.layer(axum::middleware::from_fn(
hasura_authn_core::resolve_session,
))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
authentication_middleware,
))
.layer(axum::middleware::from_fn(sql_request_tracing_middleware))
// *PLEASE DO NOT ADD ANY MIDDLEWARE
// BEFORE THE `explain_request_tracing_middleware`*
// Refer to it for more details.
.layer(TraceLayer::new_for_http())
.with_state(state);
self.sql_routes = Some(sql_routes);
}
pub fn add_jsonapi_route(&mut self, state: EngineState) {
let jsonapi_routes = create_json_api_router(state);
self.jsonapi_routes = Some(jsonapi_routes);
}
pub fn add_cors_layer(&mut self, allow_origin: &[String]) {
self.cors_layer = Some(build_cors_layer(allow_origin));
}
pub fn into_make_service(self) -> axum::routing::IntoMakeService<Router> {
let mut app = self.base_router;
// Merge the metadata routes if they exist.
if let Some(sql_routes) = self.sql_routes {
app = app.merge(sql_routes);
}
if let Some(jsonapi_routes) = self.jsonapi_routes {
app = app.merge(jsonapi_routes);
}
// Merge the metadata routes if they exist.
if let Some(metadata_routes) = self.metadata_routes {
app = app.merge(metadata_routes);
}
// Add the CORS layer if it exists.
if let Some(cors_layer) = self.cors_layer {
// It is important that this layer is added last, since it only affects
// the layers that precede it.
app = app.layer(cors_layer);
}
app.into_make_service()
}
}
/// Health check endpoint
async fn handle_health() -> reqwest::StatusCode {
reqwest::StatusCode::OK
}
async fn graphiql() -> Html<&'static str> {
Html(include_str!("index.html"))
}

View File

@ -0,0 +1,103 @@
use axum::{extract::State, response::IntoResponse, Extension, Json};
use futures_util::FutureExt;
use crate::EngineState;
use hasura_authn_core::Session;
use lang_graphql as gql;
use tracing_util::{set_status_on_current_span, SpanVisibility};
pub async fn handle_request(
headers: axum::http::header::HeaderMap,
State(state): State<EngineState>,
Extension(session): Extension<Session>,
Json(request): Json<gql::http::RawRequest>,
) -> gql::http::Response {
let tracer = tracing_util::global_tracer();
let response = tracer
.in_span_async(
"handle_request",
"Handle request",
SpanVisibility::User,
|| {
{
Box::pin(
graphql_frontend::execute_query(
state.expose_internal_errors,
&state.http_context,
&state.graphql_state,
&session,
&headers,
request,
None,
)
.map(|(_operation_type, graphql_response)| graphql_response),
)
}
},
)
.await;
// Set the span as error if the response contains an error
// NOTE: Ideally, we should mark the root span as error in `graphql_request_tracing_middleware` function,
// the tracing middleware, where the span is initialized. It is possible by completing the implementation
// of `Traceable` trait for `AxumResponse` struct. The said struct just wraps the `axum::response::Response`.
// The only way to determine the error is to inspect the status code from the `Response` struct.
// In `/graphql` API, all responses are sent with `200` OK including errors, which leaves no way to deduce errors in the tracing middleware.
set_status_on_current_span(&response);
response.inner()
}
pub async fn handle_explain_request(
headers: axum::http::header::HeaderMap,
State(state): State<EngineState>,
Extension(session): Extension<Session>,
Json(request): Json<gql::http::RawRequest>,
) -> graphql_frontend::ExplainResponse {
let tracer = tracing_util::global_tracer();
let response = tracer
.in_span_async(
"handle_explain_request",
"Handle explain request",
SpanVisibility::User,
|| {
Box::pin(
graphql_frontend::execute_explain(
state.expose_internal_errors,
&state.http_context,
&state.graphql_state,
&session,
&headers,
request,
)
.map(|(_operation_type, graphql_response)| graphql_response),
)
},
)
.await;
// Set the span as error if the response contains an error
set_status_on_current_span(&response);
response
}
pub async fn handle_websocket_request(
headers: axum::http::header::HeaderMap,
State(engine_state): State<EngineState>,
ws: axum::extract::ws::WebSocketUpgrade,
) -> impl IntoResponse {
// Create the context for the websocket server
let context = graphql_ws::Context {
connection_expiry: graphql_ws::ConnectionExpiry::Never,
http_context: engine_state.http_context,
project_id: None, // project_id is not needed for OSS v3-engine.
expose_internal_errors: engine_state.expose_internal_errors,
schema: engine_state.graphql_state,
auth_config: engine_state.auth_config,
plugin_configs: engine_state.plugin_configs,
metrics: graphql_ws::NoOpWebSocketMetrics, // No metrics implementation
};
engine_state
.graphql_websocket_server
.upgrade_and_handle_websocket(ws, &headers, context)
}

View File

@ -12,13 +12,13 @@ use tracing_util::{set_status_on_current_span, SpanVisibility, Traceable};
use crate::{authentication_middleware, EngineState};
pub(crate) fn create_json_api_router(state: EngineState) -> axum::Router {
pub fn create_json_api_router(state: EngineState) -> axum::Router {
let router = Router::new()
.route("/__schema", get(handle_schema))
.route("/__schema", get(handle_rest_schema))
// TODO: update method GET; for now we are only supporting queries. And
// in JSON:API spec, all queries have the GET method. Not even HEAD is
// supported. So this should be fine.
.route("/*path", get(handle_request))
.route("/*path", get(handle_rest_request))
.layer(axum::middleware::from_fn(
hasura_authn_core::resolve_session,
))
@ -56,7 +56,7 @@ impl Traceable for JsonApiSchemaResponse {
}
}
async fn handle_schema(
async fn handle_rest_schema(
axum::extract::State(state): axum::extract::State<EngineState>,
Extension(session): Extension<Session>,
) -> impl IntoResponse {
@ -78,7 +78,7 @@ async fn handle_schema(
)
}
async fn handle_request(
async fn handle_rest_request(
request_headers: HeaderMap,
method: Method,
uri: Uri,

View File

@ -0,0 +1,55 @@
use crate::EngineState;
use axum::{extract::State, response::IntoResponse, Extension, Json};
use reqwest::header::CONTENT_TYPE;
use std::sync::Arc;
use hasura_authn_core::Session;
use tracing_util::{set_status_on_current_span, SpanVisibility};
/// Handle a SQL request and execute it.
pub async fn handle_sql_request(
headers: axum::http::header::HeaderMap,
State(state): State<EngineState>,
Extension(session): Extension<Session>,
Json(request): Json<sql::execute::SqlRequest>,
) -> axum::response::Response {
let tracer = tracing_util::global_tracer();
let response = tracer
.in_span_async(
"handle_sql_request",
"Handle SQL Request",
SpanVisibility::User,
|| {
Box::pin(async {
sql::execute::execute_sql(
Arc::new(headers),
state.sql_context.clone(),
Arc::new(session),
Arc::new(state.http_context.clone()),
&request,
)
.await
})
},
)
.await;
// Set the span as error if the response contains an error
set_status_on_current_span(&response);
match response {
Ok(r) => {
let mut response = (axum::http::StatusCode::OK, r).into_response();
response.headers_mut().insert(
CONTENT_TYPE,
axum::http::HeaderValue::from_static("application/json"),
);
response
}
Err(e) => (
axum::http::StatusCode::BAD_REQUEST,
Json(e.to_error_response()),
)
.into_response(),
}
}

View File

@ -0,0 +1,70 @@
use crate::{EngineState, StartupError};
use hasura_authn::resolve_auth_config;
use std::fmt::Display;
use std::path::PathBuf;
use std::sync::Arc;
use execute::HttpContext;
#[allow(clippy::print_stdout)]
/// Print any build warnings to stdout
fn print_warnings<T: Display>(warnings: Vec<T>) {
for warning in warnings {
println!("Warning: {warning}");
}
}
/// Build the engine state - include auth, metadata, and sql context.
pub fn build_state(
expose_internal_errors: execute::ExposeInternalErrors,
authn_config_path: &PathBuf,
metadata_path: &PathBuf,
enable_sql_interface: bool,
metadata_resolve_configuration: &metadata_resolve::configuration::Configuration,
) -> Result<EngineState, anyhow::Error> {
// Auth Config
let raw_auth_config = std::fs::read_to_string(authn_config_path)?;
let (auth_config, auth_warnings) =
resolve_auth_config(&raw_auth_config).map_err(StartupError::ReadAuth)?;
// Metadata
let raw_metadata = std::fs::read_to_string(metadata_path)?;
let metadata = open_dds::Metadata::from_json_str(&raw_metadata)?;
let (resolved_metadata, warnings) =
metadata_resolve::resolve(metadata, metadata_resolve_configuration)?;
let resolved_metadata = Arc::new(resolved_metadata);
print_warnings(auth_warnings);
print_warnings(warnings);
let http_context = HttpContext {
client: reqwest::Client::new(),
ndc_response_size_limit: None,
};
let plugin_configs = resolved_metadata.plugin_configs.clone();
let sql_context = if enable_sql_interface {
sql::catalog::Catalog::from_metadata(resolved_metadata.clone())
} else {
sql::catalog::Catalog::empty_from_metadata(resolved_metadata.clone())
};
let schema = graphql_schema::GDS {
metadata: resolved_metadata.clone(),
}
.build_schema()?;
let (jsonapi_catalog, _json_api_warnings) = jsonapi::Catalog::new(&resolved_metadata);
let state = EngineState {
expose_internal_errors,
http_context,
graphql_state: Arc::new(schema),
jsonapi_catalog: Arc::new(jsonapi_catalog),
resolved_metadata,
auth_config: Arc::new(auth_config),
sql_context: sql_context.into(),
plugin_configs: Arc::new(plugin_configs),
graphql_websocket_server: Arc::new(graphql_ws::WebSocketServer::new()),
};
Ok(state)
}

View File

@ -0,0 +1,37 @@
use hasura_authn::AuthConfig;
use metadata_resolve::LifecyclePluginConfigs;
use std::sync::Arc;
use execute::HttpContext;
use graphql_schema::GDS;
use lang_graphql as gql;
use tracing_util::{ErrorVisibility, TraceableError};
#[derive(Clone)] // Cheap to clone as heavy fields are wrapped in `Arc`
pub struct EngineState {
pub expose_internal_errors: execute::ExposeInternalErrors,
pub http_context: HttpContext,
pub graphql_state: Arc<gql::schema::Schema<GDS>>,
pub resolved_metadata: Arc<metadata_resolve::Metadata>,
pub jsonapi_catalog: Arc<jsonapi::Catalog>,
pub auth_config: Arc<AuthConfig>,
pub sql_context: Arc<sql::catalog::Catalog>,
pub plugin_configs: Arc<LifecyclePluginConfigs>,
pub graphql_websocket_server:
Arc<graphql_ws::WebSocketServer<graphql_ws::NoOpWebSocketMetrics>>,
}
#[derive(thiserror::Error, Debug)]
#[allow(clippy::enum_variant_names)]
pub enum StartupError {
#[error("could not read the auth config - {0}")]
ReadAuth(anyhow::Error),
#[error("failed to build engine state - {0}")]
ReadSchema(anyhow::Error),
}
impl TraceableError for StartupError {
fn visibility(&self) -> tracing_util::ErrorVisibility {
ErrorVisibility::User
}
}