pre-execution plugin middleware (#771)

The RFC:
https://docs.google.com/document/d/1NB9fA6J8_dKtWknJkfTiqN5qWhPVDfYVOKFK25gYE7Y/edit

The JIRA: https://hasurahq.atlassian.net/browse/V3ENGINE-234

<!-- The PR description should answer 2 (maybe 3) important questions:
-->

### What

<!-- What is this PR trying to accomplish (and why, if it's not
obvious)? -->

<!-- Consider: do we need to add a changelog entry? -->

This PR adds a new middleware to the `/graphql` endpoint handler. This
new middleware will be used to handle the pre-execution plugins.

### How

<!-- How is it trying to accomplish it (what are the implementation
steps)? -->

We are doing something similar to
https://github.com/tokio-rs/axum/blob/axum-v0.6.20/examples/consume-body-in-extractor-or-middleware/src/main.rs

V3_GIT_ORIGIN_REV_ID: a7e0a7a252efd1f266b3e90df23a8307a9b35fc7
This commit is contained in:
paritosh-08 2024-07-03 18:27:38 +05:30 committed by hasura-bot
parent e380876823
commit 08014cab88
13 changed files with 541 additions and 3 deletions

16
v3/Cargo.lock generated
View File

@ -1726,6 +1726,7 @@ dependencies = [
"metadata-resolve",
"open-dds",
"opendds-derive",
"pre-execution-plugin",
"pretty_assertions",
"reqwest",
"schema",
@ -3560,6 +3561,21 @@ version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "pre-execution-plugin"
version = "0.1.0"
dependencies = [
"axum",
"hasura-authn-core",
"lang-graphql",
"reqwest",
"schemars",
"serde",
"serde_json",
"thiserror",
"tracing-util",
]
[[package]]
name = "pretty_assertions"
version = "1.4.0"

View File

@ -14,6 +14,7 @@ members = [
"crates/metadata-resolve",
"crates/metadata-schema-generator",
"crates/open-dds",
"crates/plugins/*",
"crates/query-usage-analytics",
"crates/schema",
"crates/sql",

View File

@ -34,7 +34,7 @@ impl SessionVariableValue {
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
pub struct SessionVariables(HashMap<SessionVariable, SessionVariableValue>);
impl SessionVariables {
@ -44,7 +44,7 @@ impl SessionVariables {
}
// The privilege with which a request is executed
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
pub struct Session {
pub role: Role,
pub variables: SessionVariables,

View File

@ -28,6 +28,7 @@ hasura-authn-webhook = { path = "../auth/hasura-authn-webhook" }
lang-graphql = { path = "../lang-graphql" }
open-dds = { path = "../open-dds" }
opendds-derive = { path = "../utils/opendds-derive" }
pre-execution-plugin = { path = "../plugins/pre-execution-plugin" }
schema = { path = "../schema" }
sql = { path = "../sql" }
tracing-util = { path = "../utils/tracing-util" }

View File

@ -14,6 +14,9 @@ use axum::{
};
use clap::Parser;
use pre_execution_plugin::{
configuration::PrePluginConfig, execute::pre_execution_plugins_handler,
};
use reqwest::header::CONTENT_TYPE;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
@ -23,9 +26,15 @@ use tracing_util::{
};
use base64::engine::Engine;
use engine::authentication::{AuthConfig, AuthConfig::V1 as V1AuthConfig, AuthModeConfig};
use engine::internal_flags::{resolve_unstable_features, UnstableFeature};
use engine::VERSION;
use engine::{
authentication::{
AuthConfig::{self, V1 as V1AuthConfig},
AuthModeConfig,
},
plugins::read_pre_execution_plugins_config,
};
use execute::HttpContext;
use hasura_authn_core::Session;
use hasura_authn_jwt::auth as jwt_auth;
@ -92,6 +101,9 @@ struct ServerOptions {
value_delimiter = ','
)]
unstable_features: Vec<UnstableFeature>,
/// The configuration file used for authentication.
#[arg(long, value_name = "PATH", env = "pre_execution_plugins_path")]
pre_execution_plugins_path: Option<PathBuf>,
/// Whether internal errors should be shown or censored.
/// It is recommended to only show errors while developing since internal errors may contain
@ -105,6 +117,7 @@ struct EngineState {
http_context: HttpContext,
schema: gql::schema::Schema<GDS>,
auth_config: AuthConfig,
pre_execution_plugins_config: Vec<PrePluginConfig>,
sql_context: sql::catalog::Context,
}
@ -171,11 +184,14 @@ async fn shutdown_signal() {
}
#[derive(thiserror::Error, Debug)]
#[allow(clippy::enum_variant_names)]
enum StartupError {
#[error("could not read the auth config - {0}")]
ReadAuth(anyhow::Error),
#[error("failed to build engine state - {0}")]
ReadSchema(anyhow::Error),
#[error("could not read the pre-execution plugins config - {0}")]
ReadPrePlugin(anyhow::Error),
}
impl TraceableError for StartupError {
@ -202,6 +218,10 @@ impl EngineRouter {
fn new(state: Arc<EngineState>) -> Self {
let graphql_route = Router::new()
.route("/graphql", post(handle_request))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
pre_execution_plugins_middleware,
))
.layer(axum::middleware::from_fn(
hasura_authn_core::resolve_session,
))
@ -338,6 +358,7 @@ async fn start_engine(server: &ServerOptions) -> Result<(), StartupError> {
expose_internal_errors,
&server.authn_config_path,
&server.metadata_path,
&server.pre_execution_plugins_path,
metadata_resolve_configuration,
)
.map_err(StartupError::ReadSchema)?;
@ -618,6 +639,32 @@ async fn handle_explain_request(
response
}
async fn pre_execution_plugins_middleware<'a, B>(
State(engine_state): State<Arc<EngineState>>,
Extension(session): Extension<Session>,
headers_map: HeaderMap,
request: Request<B>,
next: Next<axum::body::Body>,
) -> axum::response::Result<axum::response::Response>
where
B: HttpBody,
B::Error: Display,
{
let (request, response) = pre_execution_plugins_handler(
&engine_state.pre_execution_plugins_config,
&engine_state.http_context.client,
session,
request,
headers_map,
)
.await?;
match response {
Some(response) => Ok(response),
None => Ok(next.run(request).await),
}
}
/// Handle a SQL request and execute it.
async fn handle_sql_request(
State(state): State<Arc<EngineState>>,
@ -669,9 +716,13 @@ fn build_state(
expose_internal_errors: execute::ExposeInternalErrors,
authn_config_path: &PathBuf,
metadata_path: &PathBuf,
pre_execution_plugins_path: &Option<PathBuf>,
metadata_resolve_configuration: metadata_resolve::configuration::Configuration,
) -> Result<Arc<EngineState>, anyhow::Error> {
let auth_config = read_auth_config(authn_config_path).map_err(StartupError::ReadAuth)?;
let pre_execution_plugins_config =
read_pre_execution_plugins_config(pre_execution_plugins_path)
.map_err(StartupError::ReadPrePlugin)?;
let raw_metadata = std::fs::read_to_string(metadata_path)?;
let metadata = open_dds::Metadata::from_json_str(&raw_metadata)?;
let resolved_metadata = metadata_resolve::resolve(metadata, metadata_resolve_configuration)?;
@ -689,6 +740,7 @@ fn build_state(
http_context,
schema,
auth_config,
pre_execution_plugins_config,
sql_context,
});
Ok(state)

View File

@ -1,6 +1,7 @@
pub mod authentication;
pub mod build;
pub mod internal_flags;
pub mod plugins;
// This is set by the build.rs script.
/// The version of the v3-engine release.

View File

@ -0,0 +1,29 @@
use pre_execution_plugin::configuration::PrePluginConfig;
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "version", content = "definition")]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
/// Definition of the Pre-execution Plugin configuration used by the API server.
enum PreExecutionPluginConfiguration {
V1(PrePluginConfig),
}
pub fn read_pre_execution_plugins_config(
path: &Option<std::path::PathBuf>,
) -> Result<Vec<PrePluginConfig>, anyhow::Error> {
let pre_plugins: Vec<PreExecutionPluginConfiguration> = match path {
Some(path) => {
let raw_pre_execution_plugins_config = std::fs::read_to_string(path)?;
Ok::<_, anyhow::Error>(serde_json::from_str(&raw_pre_execution_plugins_config)?)
}
None => Ok(vec![]),
}?;
Ok(pre_plugins
.into_iter()
.map(|p| match p {
PreExecutionPluginConfiguration::V1(config) => config,
})
.collect())
}

View File

@ -133,6 +133,23 @@ impl Response {
}
}
pub fn error_message_with_status_and_details(
status_code: http::status::StatusCode,
message: String,
details: serde_json::Value,
) -> Self {
Self {
status_code,
headers: http::HeaderMap::default(),
data: None,
errors: Some(nonempty![GraphQLError {
message,
path: None,
extensions: Some(Extensions { details }),
}]),
}
}
pub fn error(error: GraphQLError, headers: http::HeaderMap) -> Self {
Self {
status_code: http::status::StatusCode::OK,

View File

@ -0,0 +1,21 @@
[package]
name = "pre-execution-plugin"
version.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
hasura-authn-core = { path = "../../auth/hasura-authn-core" }
lang-graphql = { path = "../../lang-graphql" }
tracing-util = { path = "../../utils/tracing-util" }
axum = { workspace = true }
reqwest = { workspace = true, features = ["json"] }
schemars = { workspace = true, features = ["smol_str", "url"] }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
[lints]
workspace = true

View File

@ -0,0 +1,70 @@
use reqwest::Url;
use schemars::JsonSchema;
use serde::{de::Error as SerdeDeError, Deserialize, Deserializer, Serialize, Serializer};
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema, PartialEq)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
#[schemars(title = "RequestConfig")]
pub struct RequestConfig {
pub headers: bool,
pub session: bool,
pub raw_request: RawRequestConfig,
}
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema, PartialEq)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
#[schemars(title = "RawRequestConfig")]
pub struct RawRequestConfig {
pub query: bool,
pub variables: bool,
}
fn serialize_url<S>(url: &Url, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
s.serialize_str(url.as_str())
}
fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
where
D: Deserializer<'de>,
{
let buf = String::deserialize(deserializer)?;
Url::parse(&buf).map_err(SerdeDeError::custom)
}
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema, PartialEq)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
#[schemars(title = "PrePluginConfig")]
#[schemars(example = "PrePluginConfig::example")]
pub struct PrePluginConfig {
pub name: String,
#[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
pub url: Url,
pub request: RequestConfig,
}
impl PrePluginConfig {
fn example() -> Self {
serde_json::from_str(
r#"{
"name": "example",
"url": "http://example.com",
"request": {
"headers": true,
"session": true,
"rawRequest": {
"query": true,
"variables": true
}
}
}"#,
)
.unwrap()
}
}

View File

@ -0,0 +1,311 @@
use std::{collections::HashMap, fmt::Display};
use axum::{
body::HttpBody,
http::{HeaderMap, Request, StatusCode},
response::IntoResponse,
};
use serde::Serialize;
use crate::configuration::PrePluginConfig;
use hasura_authn_core::Session;
use lang_graphql::{ast::common as ast, http::RawRequest};
use thiserror::Error;
use tracing_util::{
set_attribute_on_active_span, ErrorVisibility, SpanVisibility, Traceable, TraceableError,
};
#[derive(Error, Debug)]
pub enum Error {
#[error("Error while making the HTTP request to the pre-execution plugin {0} - {1}")]
ErrorWhileMakingHTTPRequestToTheHook(String, reqwest::Error),
#[error("Reqwest error: {0}")]
ReqwestError(reqwest::Error),
#[error("Unexpected status code: {0}")]
UnexpectedStatusCode(u16),
#[error("plugin response parse error: {0}")]
PluginResponseParseError(serde_json::error::Error),
}
impl TraceableError for Error {
fn visibility(&self) -> ErrorVisibility {
ErrorVisibility::Internal
}
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
lang_graphql::http::Response::error_message_with_status(
StatusCode::INTERNAL_SERVER_ERROR,
self.to_string(),
)
.into_response()
}
}
#[derive(Debug, Clone)]
pub enum ErrorResponse {
UserError(Vec<u8>),
InternalError(Option<Vec<u8>>),
}
impl std::fmt::Display for ErrorResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let message = match self {
ErrorResponse::UserError(error) | ErrorResponse::InternalError(Some(error)) => {
let error = serde_json::from_slice::<serde_json::Value>(error)
.map_err(|_| std::fmt::Error)?;
error.to_string()
}
ErrorResponse::InternalError(None) => String::new(),
};
write!(f, "{message}")
}
}
impl Traceable for PreExecutePluginResponse {
type ErrorType<'a> = ErrorResponse;
fn get_error(&self) -> Option<ErrorResponse> {
match self {
PreExecutePluginResponse::Continue | PreExecutePluginResponse::Return(_) => None,
PreExecutePluginResponse::ReturnError(err) => Some(err.clone()),
}
}
}
impl TraceableError for ErrorResponse {
fn visibility(&self) -> ErrorVisibility {
match &self {
ErrorResponse::UserError(_) => ErrorVisibility::User,
ErrorResponse::InternalError(_) => ErrorVisibility::Internal,
}
}
}
#[derive(Debug, Clone)]
pub enum PreExecutePluginResponse {
Return(Vec<u8>),
Continue,
ReturnError(ErrorResponse),
}
#[derive(Serialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct RawRequestBody {
pub query: Option<String>,
pub variables: Option<HashMap<ast::Name, serde_json::Value>>,
pub operation_name: Option<ast::Name>,
}
#[derive(Serialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct PreExecutePluginRequestBody {
pub session: Option<Session>,
pub raw_request: RawRequestBody,
}
fn build_request(
http_client: &reqwest::Client,
config: &PrePluginConfig,
client_headers: &HeaderMap,
session: &Session,
raw_request: &RawRequest,
) -> reqwest::RequestBuilder {
let mut pre_plugin_headers = tracing_util::get_trace_headers();
if config.request.headers {
pre_plugin_headers.extend(client_headers.clone());
}
let mut request_builder = http_client
.post(config.url.clone())
.headers(pre_plugin_headers);
let mut request_body = PreExecutePluginRequestBody {
session: None,
raw_request: RawRequestBody {
query: None,
variables: None,
operation_name: raw_request.operation_name.clone(),
},
};
if config.request.session {
request_body.session = Some(session.clone());
};
if config.request.raw_request.query {
request_body.raw_request.query = Some(raw_request.query.clone());
};
if config.request.raw_request.variables {
request_body
.raw_request
.variables
.clone_from(&raw_request.variables);
};
request_builder = request_builder.json(&request_body);
request_builder
}
pub async fn execute_plugin(
http_client: &reqwest::Client,
config: &PrePluginConfig,
client_headers: &HeaderMap,
session: &Session,
raw_request: &RawRequest,
) -> Result<PreExecutePluginResponse, Error> {
let tracer = tracing_util::global_tracer();
let response = tracer
.in_span_async(
"request_to_webhook",
"Send request to webhook",
SpanVisibility::Internal,
|| {
Box::pin(async {
let http_request_builder =
build_request(http_client, config, client_headers, session, raw_request);
let req = http_request_builder.build().map_err(Error::ReqwestError)?;
http_client.execute(req).await.map_err(|e| {
Error::ErrorWhileMakingHTTPRequestToTheHook(config.name.clone(), e)
})
})
},
)
.await?;
match response.status() {
StatusCode::NO_CONTENT => Ok(PreExecutePluginResponse::Continue),
StatusCode::OK => {
let body = response.bytes().await.map_err(Error::ReqwestError)?;
Ok(PreExecutePluginResponse::Return(body.to_vec()))
}
StatusCode::INTERNAL_SERVER_ERROR => {
let body = response.bytes().await.map_err(Error::ReqwestError)?;
Ok(PreExecutePluginResponse::ReturnError(
ErrorResponse::InternalError(Some(body.to_vec())),
))
}
StatusCode::BAD_REQUEST => {
let body = response.bytes().await.map_err(Error::ReqwestError)?;
Ok(PreExecutePluginResponse::ReturnError(
ErrorResponse::UserError(body.to_vec()),
))
}
_ => Err(Error::UnexpectedStatusCode(response.status().as_u16())),
}
}
pub async fn pre_execution_plugins_handler<'a, B>(
pre_execution_plugins_config: &Vec<PrePluginConfig>,
http_client: &reqwest::Client,
session: Session,
request: Request<B>,
headers_map: HeaderMap,
) -> axum::response::Result<(Request<axum::body::Body>, Option<axum::response::Response>)>
where
B: HttpBody,
B::Error: Display,
{
let (parts, body) = request.into_parts();
let bytes = body
.collect()
.await
.map_err(|err| {
(reqwest::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response()
})?
.to_bytes();
let tracer = tracing_util::global_tracer();
let mut response = None;
let raw_request =
serde_json::from_slice::<RawRequest>(&bytes).map_err(Error::PluginResponseParseError)?;
for pre_plugin_config in pre_execution_plugins_config {
let plugin_response = tracer
.in_span_async(
"pre_execution_plugin_middleware",
"Pre-execution Plugin middleware",
SpanVisibility::Internal,
|| {
Box::pin(async {
set_attribute_on_active_span(
tracing_util::AttributeVisibility::Default,
"plugin.name",
pre_plugin_config.name.clone(),
);
let plugin_response = execute_plugin(
http_client,
pre_plugin_config,
&headers_map,
&session,
&raw_request,
)
.await;
if let Ok(PreExecutePluginResponse::ReturnError(
ErrorResponse::InternalError(error_value),
)) = &plugin_response
{
let error_value = serde_json::from_slice::<serde_json::Value>(
error_value.as_ref().unwrap_or(&vec![]),
)
.map_err(Error::PluginResponseParseError)?;
set_attribute_on_active_span(
tracing_util::AttributeVisibility::Default,
"plugin.internal_error",
error_value.to_string(),
);
};
if let Ok(PreExecutePluginResponse::ReturnError(
ErrorResponse::UserError(error_value),
)) = &plugin_response
{
let error_value =
serde_json::from_slice::<serde_json::Value>(error_value)
.map_err(Error::PluginResponseParseError)?;
set_attribute_on_active_span(
tracing_util::AttributeVisibility::Default,
"plugin.user_error",
error_value.to_string(),
);
}
plugin_response
})
},
)
.await?;
match plugin_response {
PreExecutePluginResponse::Return(value) => {
response = Some(value.into_response());
break;
}
PreExecutePluginResponse::Continue => (),
PreExecutePluginResponse::ReturnError(ErrorResponse::UserError(error_value)) => {
let error_value = serde_json::from_slice::<serde_json::Value>(&error_value)
.map_err(Error::PluginResponseParseError)?;
let user_error_response =
lang_graphql::http::Response::error_message_with_status_and_details(
reqwest::StatusCode::BAD_REQUEST,
format!(
"User error in pre-execution plugin {0}",
pre_plugin_config.name
),
error_value,
)
.into_response();
response = Some(user_error_response);
break;
}
PreExecutePluginResponse::ReturnError(ErrorResponse::InternalError(_error_value)) => {
let internal_error_response =
lang_graphql::http::Response::error_message_with_status(
reqwest::StatusCode::INTERNAL_SERVER_ERROR,
format!(
"Internal error in pre-execution plugin {0}",
pre_plugin_config.name
),
)
.into_response();
response = Some(internal_error_response);
break;
}
};
}
Ok((
Request::from_parts(parts, axum::body::Body::from(bytes)),
response,
))
}

View File

@ -0,0 +1,2 @@
pub mod configuration;
pub mod execute;

17
v3/pre_plugins.json Normal file
View File

@ -0,0 +1,17 @@
[
{
"version": "v1",
"definition": {
"name": "example",
"url": "http://localhost:5000/allow_list",
"request": {
"headers": false,
"session": true,
"rawRequest": {
"query": true,
"variables": true
}
}
}
}
]