diff --git a/v3/Cargo.lock b/v3/Cargo.lock index 0eddcdff859..2c2c1832b36 100644 --- a/v3/Cargo.lock +++ b/v3/Cargo.lock @@ -955,6 +955,7 @@ dependencies = [ "base64 0.21.7", "bincode", "build-data", + "bytes", "clap 4.5.4", "criterion", "derive_more", @@ -968,6 +969,7 @@ dependencies = [ "json_value_merge", "lang-graphql", "lazy_static", + "mockito", "ndc-client", "nonempty", "open-dds", diff --git a/v3/crates/engine/Cargo.toml b/v3/crates/engine/Cargo.toml index ac5238b9b96..2e44eb18dbd 100644 --- a/v3/crates/engine/Cargo.toml +++ b/v3/crates/engine/Cargo.toml @@ -32,6 +32,7 @@ async-recursion = "1.0.5" axum = { version = "0.6.20" } base64 = "0.21.2" bincode = "1.3.3" +bytes = "1.6.0" clap = { version = "4", features = ["derive", "env"] } derive_more = "0.99.17" futures = "0.3.29" @@ -66,8 +67,9 @@ build-data = "0.1.5" # To set short commit-sha at build time [dev-dependencies] criterion = { version = "0.4", features = ["html_reports", "async_tokio"] } goldenfile = "1.4.5" -tokio-test = "0.4.2" +mockito = { version = "1.1.0", default-features = false, features = [] } pretty_assertions = "1.3.0" +tokio-test = "0.4.2" [package.metadata.cargo-machete] ignored = [ diff --git a/v3/crates/engine/benches/execute.rs b/v3/crates/engine/benches/execute.rs index dcc8214c684..552ea4639b0 100644 --- a/v3/crates/engine/benches/execute.rs +++ b/v3/crates/engine/benches/execute.rs @@ -1,6 +1,6 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use engine::execute::plan::{execute_mutation_plan, execute_query_plan, generate_request_plan}; -use engine::execute::{execute_query_internal, generate_ir}; +use engine::execute::{execute_query_internal, generate_ir, HttpContext}; use engine::schema::GDS; use hasura_authn_core::Identity; use lang_graphql::http::RawRequest; @@ -47,8 +47,10 @@ pub fn bench_execute( let gds = GDS::new(open_dds::traits::OpenDd::deserialize(metadata).unwrap()).unwrap(); let schema = GDS::build_schema(&gds).unwrap(); - - let http_client = reqwest::Client::new(); + let http_context = HttpContext { + client: reqwest::Client::new(), + ndc_response_size_limit: None, + }; let runtime = Runtime::new().unwrap(); let query = fs::read_to_string(request_path).unwrap(); @@ -132,10 +134,10 @@ pub fn bench_execute( b.to_async(*runtime).iter(|| async { match generate_request_plan(&ir).unwrap() { engine::execute::plan::RequestPlan::QueryPlan(query_plan) => { - execute_query_plan(&http_client, query_plan, None).await + execute_query_plan(&http_context, query_plan, None).await } engine::execute::plan::RequestPlan::MutationPlan(mutation_plan) => { - execute_mutation_plan(&http_client, mutation_plan, None).await + execute_mutation_plan(&http_context, mutation_plan, None).await } } }) @@ -148,7 +150,7 @@ pub fn bench_execute( &(&runtime, &schema, raw_request), |b, (runtime, schema, request)| { b.to_async(*runtime).iter(|| async { - execute_query_internal(&http_client, schema, &session, request.clone(), None) + execute_query_internal(&http_context, schema, &session, request.clone(), None) .await .unwrap() }) diff --git a/v3/crates/engine/bin/engine/main.rs b/v3/crates/engine/bin/engine/main.rs index 990a136d5ba..dfcfa2bd04e 100644 --- a/v3/crates/engine/bin/engine/main.rs +++ b/v3/crates/engine/bin/engine/main.rs @@ -19,7 +19,10 @@ use tracing_util::{ TraceableError, TraceableHttpResponse, }; -use engine::authentication::{AuthConfig, AuthConfig::V1 as V1AuthConfig, AuthModeConfig}; +use engine::{ + authentication::{AuthConfig, AuthConfig::V1 as V1AuthConfig, AuthModeConfig}, + execute::HttpContext, +}; use engine::{schema::GDS, VERSION}; use hasura_authn_core::Session; use hasura_authn_jwt::auth as jwt_auth; @@ -43,7 +46,7 @@ struct ServerOptions { } struct EngineState { - http_client: reqwest::Client, + http_context: HttpContext, schema: gql::schema::Schema, auth_config: AuthConfig, } @@ -124,8 +127,12 @@ async fn start_engine(server: &ServerOptions) -> Result<(), StartupError> { let auth_config = read_auth_config(&server.authn_config_path).map_err(StartupError::ReadAuth)?; let schema = read_schema(&server.metadata_path).map_err(StartupError::ReadSchema)?; + let http_context = HttpContext { + client: reqwest::Client::new(), + ndc_response_size_limit: None, + }; let state = Arc::new(EngineState { - http_client: reqwest::Client::new(), + http_context, schema, auth_config, }); @@ -313,7 +320,7 @@ where V1AuthConfig(auth_config) => match &auth_config.mode { AuthModeConfig::Webhook(webhook_config) => { webhook::authenticate_request( - &engine_state.http_client, + &engine_state.http_context.client, webhook_config, &headers_map, auth_config.allow_role_emulation_by.clone(), @@ -323,7 +330,7 @@ where } AuthModeConfig::Jwt(jwt_secret_config) => { jwt_auth::authenticate_request( - &engine_state.http_client, + &engine_state.http_context.client, *jwt_secret_config.clone(), auth_config.allow_role_emulation_by.clone(), &headers_map, @@ -355,7 +362,7 @@ async fn handle_request( let response = tracer .in_span_async("Handle request", SpanVisibility::User, || { Box::pin(engine::execute::execute_query( - &state.http_client, + &state.http_context, &state.schema, &session, request, @@ -383,7 +390,7 @@ async fn handle_explain_request( let response = tracer .in_span_async("Handle explain request", SpanVisibility::User, || { Box::pin(engine::execute::explain::execute_explain( - &state.http_client, + &state.http_context, &state.schema, &session, request, diff --git a/v3/crates/engine/src/execute.rs b/v3/crates/engine/src/execute.rs index 4918ca334a1..0011561814c 100644 --- a/v3/crates/engine/src/execute.rs +++ b/v3/crates/engine/src/execute.rs @@ -28,6 +28,14 @@ pub mod plan; pub mod process_response; pub mod remote_joins; +/// Context for making HTTP requests +pub struct HttpContext { + /// The HTTP client to use for making requests + pub client: reqwest::Client, + /// Response size limit for NDC requests + pub ndc_response_size_limit: Option, +} + #[derive(Debug)] /// A simple wrapper around a reference of GraphQL errors pub struct GraphQLErrors<'a>(pub &'a nonempty::NonEmpty); @@ -83,13 +91,13 @@ impl Traceable for ExecuteOrExplainResponse { pub struct ProjectId(pub String); pub async fn execute_query( - http_client: &reqwest::Client, + http_context: &HttpContext, schema: &Schema, session: &Session, request: RawRequest, project_id: Option, ) -> GraphQLResponse { - execute_query_internal(http_client, schema, session, request, project_id) + execute_query_internal(http_context, schema, session, request, project_id) .await .unwrap_or_else(|e| GraphQLResponse(Response::error(e.to_graphql_error(None)))) } @@ -115,14 +123,14 @@ impl TraceableError for GraphQlValidationError { /// Executes a GraphQL query pub async fn execute_query_internal( - http_client: &reqwest::Client, + http_context: &HttpContext, schema: &gql::schema::Schema, session: &Session, raw_request: gql::http::RawRequest, project_id: Option, ) -> Result { let query_response = execute_request_internal( - http_client, + http_context, schema, session, raw_request, @@ -142,7 +150,7 @@ pub async fn execute_query_internal( /// Executes or explains (query plan) a GraphQL query pub async fn execute_request_internal( - http_client: &reqwest::Client, + http_context: &HttpContext, schema: &gql::schema::Schema, session: &Session, raw_request: gql::http::RawRequest, @@ -227,7 +235,7 @@ pub async fn execute_request_internal( let execute_query_result = match request_plan { plan::RequestPlan::MutationPlan(mutation_plan) => { plan::execute_mutation_plan( - http_client, + http_context, mutation_plan, project_id, ) @@ -235,7 +243,7 @@ pub async fn execute_request_internal( } plan::RequestPlan::QueryPlan(query_plan) => { plan::execute_query_plan( - http_client, + http_context, query_plan, project_id, ) @@ -257,14 +265,14 @@ pub async fn execute_request_internal( let request_result = match request_plan { plan::RequestPlan::MutationPlan(mutation_plan) => { crate::execute::explain::explain_mutation_plan( - http_client, + http_context, mutation_plan, ) .await } plan::RequestPlan::QueryPlan(query_plan) => { crate::execute::explain::explain_query_plan( - http_client, + http_context, query_plan, ) .await diff --git a/v3/crates/engine/src/execute/error.rs b/v3/crates/engine/src/execute/error.rs index ba7d651db14..d2b053bc6a9 100644 --- a/v3/crates/engine/src/execute/error.rs +++ b/v3/crates/engine/src/execute/error.rs @@ -252,6 +252,9 @@ fn render_ndc_error(error: &ndc_client::Error) -> String { ndc_client::Error::InvalidConnectorError(invalid_connector_err) => { format!("invalid connector error: {0}", invalid_connector_err) } + ndc_client::Error::ResponseTooLarge(err) => { + format!("response received from connector is too large: {0}", err) + } } } diff --git a/v3/crates/engine/src/execute/explain.rs b/v3/crates/engine/src/execute/explain.rs index 66f5cfe0451..17b0e1e11df 100644 --- a/v3/crates/engine/src/execute/explain.rs +++ b/v3/crates/engine/src/execute/explain.rs @@ -1,5 +1,5 @@ use super::remote_joins::types::{JoinNode, RemoteJoinType}; -use super::ExecuteOrExplainResponse; +use super::{ExecuteOrExplainResponse, HttpContext}; use crate::execute::ndc::client as ndc_client; use crate::execute::plan::{ApolloFederationSelect, NodeQueryPlan, ProcessResponseAs}; use crate::execute::remote_joins::types::{JoinId, JoinLocations, RemoteJoin}; @@ -17,25 +17,25 @@ pub mod types; use lang_graphql::ast::common as ast; pub async fn execute_explain( - http_client: &reqwest::Client, + http_context: &HttpContext, schema: &Schema, session: &Session, request: RawRequest, ) -> types::ExplainResponse { - execute_explain_internal(http_client, schema, session, request) + execute_explain_internal(http_context, schema, session, request) .await .unwrap_or_else(|e| types::ExplainResponse::error(e.to_graphql_error(None))) } /// Explains a GraphQL query pub async fn execute_explain_internal( - http_client: &reqwest::Client, + http_context: &HttpContext, schema: &gql::schema::Schema, session: &Session, raw_request: gql::http::RawRequest, ) -> Result { let query_response = super::execute_request_internal( - http_client, + http_context, schema, session, raw_request, @@ -55,7 +55,7 @@ pub async fn execute_explain_internal( /// Produce an /explain plan for a given GraphQL query. pub(crate) async fn explain_query_plan( - http_client: &reqwest::Client, + http_context: &HttpContext, query_plan: plan::QueryPlan<'_, '_, '_>, ) -> Result { let mut parallel_root_steps = vec![]; @@ -64,7 +64,7 @@ pub(crate) async fn explain_query_plan( match node { NodeQueryPlan::NDCQueryExecution(ndc_query_execution) => { let sequence_steps = get_execution_steps( - http_client, + http_context, alias, &ndc_query_execution.process_response_as, ndc_query_execution.execution_tree.remote_executions, @@ -76,7 +76,7 @@ pub(crate) async fn explain_query_plan( } NodeQueryPlan::RelayNodeSelect(Some(ndc_query_execution)) => { let sequence_steps = get_execution_steps( - http_client, + http_context, alias, &ndc_query_execution.process_response_as, ndc_query_execution.execution_tree.remote_executions, @@ -92,7 +92,7 @@ pub(crate) async fn explain_query_plan( let mut parallel_steps = Vec::new(); for ndc_query_execution in parallel_ndc_query_executions { let sequence_steps = get_execution_steps( - http_client, + http_context, alias.clone(), &ndc_query_execution.process_response_as, ndc_query_execution.execution_tree.remote_executions, @@ -143,7 +143,7 @@ pub(crate) async fn explain_query_plan( /// Produce an /explain plan for a given GraphQL mutation. pub(crate) async fn explain_mutation_plan( - http_client: &reqwest::Client, + http_context: &HttpContext, mutation_plan: plan::MutationPlan<'_, '_, '_>, ) -> Result { let mut root_steps = vec![]; @@ -157,7 +157,7 @@ pub(crate) async fn explain_mutation_plan( for (_, mutation_group) in mutation_plan.nodes { for (alias, ndc_mutation_execution) in mutation_group { let sequence_steps = get_execution_steps( - http_client, + http_context, alias, &ndc_mutation_execution.process_response_as, ndc_mutation_execution.join_locations, @@ -182,7 +182,7 @@ pub(crate) async fn explain_mutation_plan( } async fn get_execution_steps<'s>( - http_client: &reqwest::Client, + http_context: &HttpContext, alias: gql::ast::common::Alias, process_response_as: &ProcessResponseAs<'s>, join_locations: JoinLocations<(RemoteJoin<'s, '_>, JoinId)>, @@ -192,9 +192,12 @@ async fn get_execution_steps<'s>( let mut sequence_steps = match process_response_as { ProcessResponseAs::CommandResponse { .. } => { // A command execution node - let data_connector_explain = - fetch_explain_from_data_connector(http_client, ndc_request.clone(), data_connector) - .await; + let data_connector_explain = fetch_explain_from_data_connector( + http_context, + ndc_request.clone(), + data_connector, + ) + .await; NonEmpty::new(Box::new(types::Step::CommandSelect( types::CommandSelectIR { command_name: alias.to_string(), @@ -205,9 +208,12 @@ async fn get_execution_steps<'s>( } ProcessResponseAs::Array { .. } | ProcessResponseAs::Object { .. } => { // A model execution node - let data_connector_explain = - fetch_explain_from_data_connector(http_client, ndc_request.clone(), data_connector) - .await; + let data_connector_explain = fetch_explain_from_data_connector( + http_context, + ndc_request.clone(), + data_connector, + ) + .await; NonEmpty::new(Box::new(types::Step::ModelSelect(types::ModelSelectIR { model_name: alias.to_string(), ndc_request, @@ -215,7 +221,8 @@ async fn get_execution_steps<'s>( }))) } }; - if let Some(join_steps) = get_join_steps(alias.to_string(), join_locations, http_client).await { + if let Some(join_steps) = get_join_steps(alias.to_string(), join_locations, http_context).await + { sequence_steps.push(Box::new(types::Step::Sequence(join_steps))); sequence_steps.push(Box::new(types::Step::HashJoin)); }; @@ -230,7 +237,7 @@ async fn get_execution_steps<'s>( async fn get_join_steps( _root_field_name: String, join_locations: JoinLocations<(RemoteJoin<'async_recursion, 'async_recursion>, JoinId)>, - http_client: &reqwest::Client, + http_context: &HttpContext, ) -> Option>> { let mut sequence_join_steps = vec![]; for (alias, location) in join_locations.locations { @@ -240,7 +247,7 @@ async fn get_join_steps( query_request.variables = Some(vec![]); let ndc_request = types::NDCRequest::Query(query_request); let data_connector_explain = fetch_explain_from_data_connector( - http_client, + http_context, ndc_request.clone(), remote_join.target_data_connector, ) @@ -265,7 +272,7 @@ async fn get_join_steps( }, ))) }; - if let Some(rest_join_steps) = get_join_steps(alias, location.rest, http_client).await { + if let Some(rest_join_steps) = get_join_steps(alias, location.rest, http_context).await { sequence_steps.push(Box::new(types::Step::Sequence(rest_join_steps))); sequence_steps.push(Box::new(types::Step::HashJoin)); }; @@ -306,7 +313,7 @@ fn simplify_step(step: Box) -> Box { } async fn fetch_explain_from_data_connector( - http_client: &reqwest::Client, + http_context: &HttpContext, ndc_request: types::NDCRequest, data_connector: &resolved::data_connector::DataConnectorLink, ) -> types::NDCExplainResponse { @@ -321,8 +328,9 @@ async fn fetch_explain_from_data_connector( base_path: data_connector.url.get_url(ast::OperationType::Query), user_agent: None, // This is isn't expensive, reqwest::Client is behind an Arc - client: http_client.clone(), + client: http_context.client.clone(), headers: data_connector.headers.0.clone(), + response_size_limit: http_context.ndc_response_size_limit, }; { // TODO: use capabilities from the data connector context diff --git a/v3/crates/engine/src/execute/ndc.rs b/v3/crates/engine/src/execute/ndc.rs index b9914b5ed82..e77a40a8d35 100644 --- a/v3/crates/engine/src/execute/ndc.rs +++ b/v3/crates/engine/src/execute/ndc.rs @@ -1,3 +1,5 @@ +pub mod response; + use axum::http::HeaderMap; use serde_json as json; @@ -9,7 +11,7 @@ use tracing_util::{set_attribute_on_active_span, AttributeVisibility, SpanVisibi use super::plan::ProcessResponseAs; use super::process_response::process_command_mutation_response; -use super::{error, ProjectId}; +use super::{error, HttpContext, ProjectId}; use crate::metadata::resolved; use crate::schema::GDS; @@ -19,7 +21,7 @@ pub const FUNCTION_IR_VALUE_COLUMN_NAME: &str = "__value"; /// Executes a NDC operation pub async fn execute_ndc_query<'n, 's>( - http_client: &reqwest::Client, + http_context: &HttpContext, query: ndc_models::QueryRequest, data_connector: &resolved::data_connector::DataConnectorLink, execution_span_attribute: String, @@ -44,7 +46,7 @@ pub async fn execute_ndc_query<'n, 's>( field_span_attribute, ); let connector_response = - fetch_from_data_connector(http_client, query, data_connector, project_id) + fetch_from_data_connector(http_context, query, data_connector, project_id) .await?; Ok(connector_response.0) }) @@ -54,7 +56,7 @@ pub async fn execute_ndc_query<'n, 's>( } pub(crate) async fn fetch_from_data_connector<'s>( - http_client: &reqwest::Client, + http_context: &HttpContext, query_request: ndc_models::QueryRequest, data_connector: &resolved::data_connector::DataConnectorLink, project_id: Option, @@ -72,8 +74,9 @@ pub(crate) async fn fetch_from_data_connector<'s>( base_path: data_connector.url.get_url(ast::OperationType::Query), user_agent: None, // This is isn't expensive, reqwest::Client is behind an Arc - client: http_client.clone(), + client: http_context.client.clone(), headers, + response_size_limit: http_context.ndc_response_size_limit, }; client::query_post(&ndc_config, query_request) .await @@ -104,7 +107,7 @@ pub fn append_project_id_to_headers( /// Executes a NDC mutation pub(crate) async fn execute_ndc_mutation<'n, 's, 'ir>( - http_client: &reqwest::Client, + http_context: &HttpContext, query: ndc_models::MutationRequest, data_connector: &resolved::data_connector::DataConnectorLink, selection_set: &'n normalized_ast::SelectionSet<'s, GDS>, @@ -131,7 +134,7 @@ pub(crate) async fn execute_ndc_mutation<'n, 's, 'ir>( field_span_attribute, ); let connector_response = fetch_from_data_connector_mutation( - http_client, + http_context, query, data_connector, project_id, @@ -173,7 +176,7 @@ pub(crate) async fn execute_ndc_mutation<'n, 's, 'ir>( } pub(crate) async fn fetch_from_data_connector_mutation<'s>( - http_client: &reqwest::Client, + http_context: &HttpContext, query_request: ndc_models::MutationRequest, data_connector: &resolved::data_connector::DataConnectorLink, project_id: Option, @@ -191,8 +194,9 @@ pub(crate) async fn fetch_from_data_connector_mutation<'s>( base_path: data_connector.url.get_url(ast::OperationType::Mutation), user_agent: None, // This is isn't expensive, reqwest::Client is behind an Arc - client: http_client.clone(), + client: http_context.client.clone(), headers, + response_size_limit: http_context.ndc_response_size_limit, }; client::mutation_post(&ndc_config, query_request) .await diff --git a/v3/crates/engine/src/execute/ndc/client.rs b/v3/crates/engine/src/execute/ndc/client.rs index 2c72b2fb8be..ad72b0b5dba 100644 --- a/v3/crates/engine/src/execute/ndc/client.rs +++ b/v3/crates/engine/src/execute/ndc/client.rs @@ -1,3 +1,4 @@ +use super::response::handle_response_with_size_limit; use ndc_client::models as ndc_models; use reqwest::header::{HeaderMap, HeaderValue}; use serde::{de::DeserializeOwned, Deserialize}; @@ -13,6 +14,7 @@ pub struct Configuration { pub user_agent: Option, pub client: reqwest::Client, pub headers: HeaderMap, + pub response_size_limit: Option, } /// Error type for the NDC API client interactions @@ -24,6 +26,7 @@ pub enum Error { ConnectorError(ConnectorError), InvalidConnectorError(InvalidConnectorError), InvalidBaseURL, + ResponseTooLarge(String), } impl fmt::Display for Error { @@ -35,6 +38,7 @@ impl fmt::Display for Error { Error::ConnectorError(e) => ("response", format!("status code {}", e.status)), Error::InvalidConnectorError(e) => ("response", format!("status code {}", e.status)), Error::InvalidBaseURL => ("url", "invalid base URL".into()), + Error::ResponseTooLarge(message) => ("response", format!("too large: {}", message)), }; write!(f, "error in {}: {}", module, e) } @@ -238,7 +242,11 @@ async fn execute_request( let resp = configuration.client.execute(request).await?; let response_status = resp.status(); - let response_content = resp.json().await?; + + let response_content = match configuration.response_size_limit { + None => resp.json().await?, + Some(size_limit) => handle_response_with_size_limit(resp, size_limit).await?, + }; if !response_status.is_client_error() && !response_status.is_server_error() { serde_json::from_value(response_content).map_err(Error::from) diff --git a/v3/crates/engine/src/execute/ndc/response.rs b/v3/crates/engine/src/execute/ndc/response.rs new file mode 100644 index 00000000000..66162f1b808 --- /dev/null +++ b/v3/crates/engine/src/execute/ndc/response.rs @@ -0,0 +1,140 @@ +use super::client as ndc_client; + +/// Handle response return from an NDC request by applying the size limit and +/// deserializing into a JSON value +pub(crate) async fn handle_response_with_size_limit( + response: reqwest::Response, + size_limit: usize, +) -> Result { + if let Some(content_length) = &response.content_length() { + // Check with content length + if *content_length > size_limit as u64 { + Err(ndc_client::Error::ResponseTooLarge(format!( + "Received content length {} exceeds the limit {}", + content_length, size_limit + ))) + } else { + Ok(response.json().await?) + } + } else { + // If no content length found, then check chunk-by-chunk + handle_response_by_chunks_with_size_limit(response, size_limit).await + } +} + +/// Handle response by chunks. For each chunk consumed, check if the total size exceeds the limit. +/// +/// This logic is separated in a function to allow testing. +async fn handle_response_by_chunks_with_size_limit( + response: reqwest::Response, + size_limit: usize, +) -> Result { + let mut size = 0; + let mut buf = bytes::BytesMut::new(); + let mut response = response; + while let Some(chunk) = response.chunk().await? { + size += chunk.len(); + if size > size_limit { + return Err(ndc_client::Error::ResponseTooLarge(format!( + "Size exceeds the limit {}", + size_limit + ))); + } else { + buf.extend_from_slice(&chunk); + } + } + Ok(serde_json::from_slice(&buf)?) +} + +#[cfg(test)] +mod test { + use pretty_assertions::assert_eq; + + #[tokio::test] + async fn test_content_length() { + let mut server = mockito::Server::new_async().await; + let test_api = server + .mock("GET", "/test") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"message": "hello"}"#) + .create(); + let response = reqwest::get(server.url() + "/test").await.unwrap(); + test_api.assert(); + let err = super::handle_response_with_size_limit(response, 10) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "error in response: too large: Received content length 20 exceeds the limit 10" + ) + } + + #[tokio::test] + async fn test_chunk_by_chunk() { + let mut server = mockito::Server::new_async().await; + let test_api = server + .mock("GET", "/test") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"message": "hello"}"#) + .create(); + let response = reqwest::get(server.url() + "/test").await.unwrap(); + test_api.assert(); + let err = super::handle_response_by_chunks_with_size_limit(response, 5) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "error in response: too large: Size exceeds the limit 5" + ) + } + + #[tokio::test] + async fn test_success() { + let json = serde_json::json!( + [ + {"name": "Alice"}, + {"name": "Bob"}, + {"name": "Charlie"} + ] + ); + let mut server = mockito::Server::new_async().await; + let test_api = server + .mock("GET", "/test") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(serde_json::to_vec(&json).unwrap()) + .create(); + let response = reqwest::get(server.url() + "/test").await.unwrap(); + test_api.assert(); + let res = super::handle_response_with_size_limit(response, 100) + .await + .unwrap(); + assert_eq!(json, res) + } + + #[tokio::test] + async fn test_success_by_chunks() { + let json = serde_json::json!( + [ + {"name": "Alice"}, + {"name": "Bob"}, + {"name": "Charlie"} + ] + ); + let mut server = mockito::Server::new_async().await; + let test_api = server + .mock("GET", "/test") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(serde_json::to_vec(&json).unwrap()) + .create(); + let response = reqwest::get(server.url() + "/test").await.unwrap(); + test_api.assert(); + let res = super::handle_response_by_chunks_with_size_limit(response, 100) + .await + .unwrap(); + assert_eq!(json, res) + } +} diff --git a/v3/crates/engine/src/execute/plan.rs b/v3/crates/engine/src/execute/plan.rs index 040dbcd6b26..f7addd7263f 100644 --- a/v3/crates/engine/src/execute/plan.rs +++ b/v3/crates/engine/src/execute/plan.rs @@ -21,7 +21,7 @@ use super::remote_joins::execute_join_locations; use super::remote_joins::types::{ JoinId, JoinLocations, JoinNode, Location, LocationKind, MonotonicCounter, RemoteJoin, }; -use super::ProjectId; +use super::{HttpContext, ProjectId}; use crate::metadata::resolved::{self, subgraph}; use crate::schema::GDS; @@ -547,7 +547,7 @@ impl ExecuteQueryResult { /// Execute a single root field's query plan to produce a result. async fn execute_query_field_plan<'n, 's, 'ir>( - http_client: &reqwest::Client, + http_context: &HttpContext, query_plan: NodeQueryPlan<'n, 's, 'ir>, project_id: Option, ) -> RootFieldResult { @@ -603,13 +603,13 @@ async fn execute_query_field_plan<'n, 's, 'ir>( } NodeQueryPlan::NDCQueryExecution(ndc_query) => RootFieldResult::new( &ndc_query.process_response_as.is_nullable(), - resolve_ndc_query_execution(http_client, ndc_query, project_id).await, + resolve_ndc_query_execution(http_context, ndc_query, project_id).await, ), NodeQueryPlan::RelayNodeSelect(optional_query) => RootFieldResult::new( &optional_query.as_ref().map_or(true, |ndc_query| { ndc_query.process_response_as.is_nullable() }), - resolve_optional_ndc_select(http_client, optional_query, project_id) + resolve_optional_ndc_select(http_context, optional_query, project_id) .await, ), NodeQueryPlan::ApolloFederationSelect( @@ -622,7 +622,7 @@ async fn execute_query_field_plan<'n, 's, 'ir>( // To run the field plans parallely, we will need to use tokio::spawn for each field plan. let task = async { (resolve_optional_ndc_select( - http_client, + http_context, Some(query), project_id.clone(), ) @@ -694,7 +694,7 @@ async fn execute_query_field_plan<'n, 's, 'ir>( /// Execute a single root field's mutation plan to produce a result. async fn execute_mutation_field_plan<'n, 's, 'ir>( - http_client: &reqwest::Client, + http_context: &HttpContext, mutation_plan: NDCMutationExecution<'n, 's, 'ir>, project_id: Option, ) -> RootFieldResult { @@ -707,7 +707,7 @@ async fn execute_mutation_field_plan<'n, 's, 'ir>( Box::pin(async { RootFieldResult::new( &mutation_plan.process_response_as.is_nullable(), - resolve_ndc_mutation_execution(http_client, mutation_plan, project_id) + resolve_ndc_mutation_execution(http_context, mutation_plan, project_id) .await, ) }) @@ -720,7 +720,7 @@ async fn execute_mutation_field_plan<'n, 's, 'ir>( /// root fields of the mutation sequentially rather than concurrently, in the order defined by the /// `IndexMap`'s keys. pub async fn execute_mutation_plan<'n, 's, 'ir>( - http_client: &reqwest::Client, + http_context: &HttpContext, mutation_plan: MutationPlan<'n, 's, 'ir>, project_id: Option, ) -> ExecuteQueryResult { @@ -743,7 +743,7 @@ pub async fn execute_mutation_plan<'n, 's, 'ir>( for (alias, field_plan) in mutation_group { executed_root_fields.push(( alias, - execute_mutation_field_plan(http_client, field_plan, project_id.clone()).await, + execute_mutation_field_plan(http_context, field_plan, project_id.clone()).await, )); } } @@ -759,7 +759,7 @@ pub async fn execute_mutation_plan<'n, 's, 'ir>( /// Given an entire plan for a query, produce a result. We do this by executing all the singular /// root fields of the query in parallel, and joining the results back together. pub async fn execute_query_plan<'n, 's, 'ir>( - http_client: &reqwest::Client, + http_context: &HttpContext, query_plan: QueryPlan<'n, 's, 'ir>, project_id: Option, ) -> ExecuteQueryResult { @@ -773,7 +773,7 @@ pub async fn execute_query_plan<'n, 's, 'ir>( let task = async { ( alias, - execute_query_field_plan(http_client, field_plan, project_id.clone()).await, + execute_query_field_plan(http_context, field_plan, project_id.clone()).await, ) }; @@ -824,7 +824,7 @@ fn resolve_schema_field( } async fn resolve_ndc_query_execution( - http_client: &reqwest::Client, + http_context: &HttpContext, ndc_query: NDCQueryExecution<'_, '_>, project_id: Option, ) -> Result { @@ -836,7 +836,7 @@ async fn resolve_ndc_query_execution( process_response_as, } = ndc_query; let mut response = ndc::execute_ndc_query( - http_client, + http_context, execution_tree.root_node.query, execution_tree.root_node.data_connector, execution_span_attribute.clone(), @@ -847,7 +847,7 @@ async fn resolve_ndc_query_execution( // TODO: Failures in remote joins should result in partial response // https://github.com/hasura/v3-engine/issues/229 execute_join_locations( - http_client, + http_context, execution_span_attribute, field_span_attribute, &mut response, @@ -861,7 +861,7 @@ async fn resolve_ndc_query_execution( } async fn resolve_ndc_mutation_execution( - http_client: &reqwest::Client, + http_context: &HttpContext, ndc_query: NDCMutationExecution<'_, '_, '_>, project_id: Option, ) -> Result { @@ -876,7 +876,7 @@ async fn resolve_ndc_mutation_execution( join_locations: _, } = ndc_query; let response = ndc::execute_ndc_mutation( - http_client, + http_context, query, data_connector, selection_set, @@ -890,12 +890,12 @@ async fn resolve_ndc_mutation_execution( } async fn resolve_optional_ndc_select( - http_client: &reqwest::Client, + http_context: &HttpContext, optional_query: Option>, project_id: Option, ) -> Result { match optional_query { None => Ok(json::Value::Null), - Some(ndc_query) => resolve_ndc_query_execution(http_client, ndc_query, project_id).await, + Some(ndc_query) => resolve_ndc_query_execution(http_context, ndc_query, project_id).await, } } diff --git a/v3/crates/engine/src/execute/remote_joins.rs b/v3/crates/engine/src/execute/remote_joins.rs index de6a5124ea9..2758b320cf1 100644 --- a/v3/crates/engine/src/execute/remote_joins.rs +++ b/v3/crates/engine/src/execute/remote_joins.rs @@ -84,7 +84,7 @@ use tracing_util::SpanVisibility; use super::ndc::execute_ndc_query; use super::plan::ProcessResponseAs; -use super::{error, ProjectId}; +use super::{error, HttpContext, ProjectId}; use self::collect::CollectArgumentResult; use types::{Argument, JoinId, JoinLocations, RemoteJoin}; @@ -97,7 +97,7 @@ pub(crate) mod types; /// for the top-level query, and executes further remote joins recursively. #[async_recursion] pub(crate) async fn execute_join_locations<'ir>( - http_client: &reqwest::Client, + http_context: &HttpContext, execution_span_attribute: String, field_span_attribute: String, lhs_response: &mut Vec, @@ -138,7 +138,7 @@ where SpanVisibility::Internal, || { Box::pin(execute_ndc_query( - http_client, + http_context, join_node.target_ndc_ir, join_node.target_data_connector, execution_span_attribute.clone(), @@ -153,7 +153,7 @@ where // will modify the `target_response` with all joins down the tree if !location.rest.locations.is_empty() { execute_join_locations( - http_client, + http_context, execution_span_attribute.clone(), // TODO: is this field span correct? field_span_attribute.clone(), diff --git a/v3/crates/engine/tests/common.rs b/v3/crates/engine/tests/common.rs index d34a71f63fa..beeab2dc5e5 100644 --- a/v3/crates/engine/tests/common.rs +++ b/v3/crates/engine/tests/common.rs @@ -11,7 +11,7 @@ use std::{ path::PathBuf, }; -use engine::execute::execute_query; +use engine::execute::{execute_query, HttpContext}; use engine::schema::GDS; extern crate json_value_merge; @@ -19,14 +19,17 @@ use json_value_merge::Merge; use serde_json::Value; pub struct GoldenTestContext { - http_client: reqwest::Client, + http_context: HttpContext, mint: Mint, } pub fn setup(test_dir: &Path) -> GoldenTestContext { - let http_client = reqwest::Client::new(); + let http_context = HttpContext { + client: reqwest::Client::new(), + ndc_response_size_limit: None, + }; let mint = Mint::new(test_dir); - GoldenTestContext { http_client, mint } + GoldenTestContext { http_context, mint } } fn resolve_session( @@ -97,7 +100,7 @@ pub fn test_execution_expectation_legacy( // Execute the test let response = - execute_query(&test_ctx.http_client, &schema, &session, raw_request, None).await; + execute_query(&test_ctx.http_context, &schema, &session, raw_request, None).await; let mut expected = test_ctx.mint.new_goldenfile_with_differ( response_path, @@ -189,7 +192,7 @@ pub(crate) fn test_introspection_expectation( let mut responses = Vec::new(); for session in sessions.iter() { let response = execute_query( - &test_ctx.http_client, + &test_ctx.http_context, &schema, session, raw_request.clone(), @@ -288,7 +291,7 @@ pub fn test_execution_expectation( let mut responses = Vec::new(); for session in sessions.iter() { let response = execute_query( - &test_ctx.http_client, + &test_ctx.http_context, &schema, session, raw_request.clone(), @@ -370,7 +373,7 @@ pub fn test_execute_explain( variables: None, }; let raw_response = engine::execute::explain::execute_explain( - &test_ctx.http_client, + &test_ctx.http_context, &schema, &session, raw_request,