gotham_ext: allow applications to dynamically configure PostRequestMiddleware

Summary: Make `PostRequestMiddleware` generic over a user-provided config struct which can be used to dynamically configure the behavior of post-request callback dispatching. Right now this is only used to support disabling hostname logging, but could be easily extended to cover more uses in the future.

Reviewed By: krallin

Differential Revision: D23495005

fbshipit-source-id: 3d59a8346f449775ec76d03c260d973d04fb90a9
This commit is contained in:
Arun Kulshreshtha 2020-09-03 11:56:33 -07:00 committed by Facebook GitHub Bot
parent cc0f2e4c40
commit 3ad7fa8b6f
3 changed files with 43 additions and 18 deletions

View File

@ -8,6 +8,7 @@ include = ["src/**/*.rs"]
[dependencies]
permission_checker = { path = "../permission_checker" }
cached_config = { git = "https://github.com/facebookexperimental/rust-shed.git", branch = "master" }
anyhow = "1.0"
async-trait = "0.1.29"
bytes = { version = "0.5", features = ["serde"] }

View File

@ -17,7 +17,7 @@ pub mod timer;
pub mod tls_session_data;
pub use client_identity::{ClientIdentity, ClientIdentityMiddleware};
pub use post_request::{PostRequestCallbacks, PostRequestMiddleware};
pub use post_request::{PostRequestCallbacks, PostRequestConfig, PostRequestMiddleware};
pub use server_identity::ServerIdentityMiddleware;
pub use timer::{HeadersDuration, RequestStartTime, TimerMiddleware};
pub use tls_session_data::TlsSessionDataMiddleware;

View File

@ -5,9 +5,11 @@
* GNU General Public License version 2.
*/
use std::panic::RefUnwindSafe;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use cached_config::ConfigHandle;
use futures::{
channel::oneshot::{self, Receiver, Sender},
prelude::*,
@ -30,31 +32,59 @@ pub struct PostRequestInfo {
pub client_hostname: Option<String>,
}
/// Trait allowing post-request callbacks to be configured dynamically.
pub trait PostRequestConfig: Clone + Send + Sync + RefUnwindSafe + 'static {
/// Specify whether the middleware should perform a potentially
/// expensive reverse DNS lookup of the client's hostname.
fn resolve_hostname(&self) -> bool {
true
}
}
#[derive(Clone)]
pub struct DefaultConfig;
impl PostRequestConfig for DefaultConfig {}
impl<C: PostRequestConfig> PostRequestConfig for ConfigHandle<C> {
fn resolve_hostname(&self) -> bool {
self.get().resolve_hostname()
}
}
/// Middleware that allows the application to register callbacks which will
/// be run upon request completion.
#[derive(Clone)]
pub struct PostRequestMiddleware;
pub struct PostRequestMiddleware<C> {
config: C,
}
impl PostRequestMiddleware {
pub fn new() -> Self {
PostRequestMiddleware
impl<C> PostRequestMiddleware<C> {
pub fn with_config(config: C) -> Self {
Self { config }
}
}
impl Default for PostRequestMiddleware<DefaultConfig> {
fn default() -> Self {
PostRequestMiddleware::with_config(DefaultConfig)
}
}
#[async_trait]
impl Middleware for PostRequestMiddleware {
impl<C: PostRequestConfig> Middleware for PostRequestMiddleware<C> {
async fn inbound(&self, state: &mut State) -> Option<Response<Body>> {
state.put(PostRequestCallbacks::new());
None
}
async fn outbound(&self, state: &mut State, _response: &mut Response<Body>) {
let config = self.config.clone();
let start_time = RequestStartTime::try_borrow_from(&state).map(|t| t.0);
let content_length = ResponseContentLength::try_borrow_from(&state).map(|l| l.0);
let hostname_future = ClientIdentity::try_borrow_from(&state).map(|id| id.hostname());
if let Some(callbacks) = state.try_take::<PostRequestCallbacks>() {
task::spawn(callbacks.run(start_time, content_length, hostname_future));
task::spawn(callbacks.run(config, start_time, content_length, hostname_future));
}
}
}
@ -64,7 +94,6 @@ impl Middleware for PostRequestMiddleware {
pub struct PostRequestCallbacks {
callbacks: Vec<Callback>,
delay_signal: Option<Receiver<u64>>,
resolve_hostname: bool,
}
impl PostRequestCallbacks {
@ -72,7 +101,6 @@ impl PostRequestCallbacks {
Self {
callbacks: Vec::new(),
delay_signal: None,
resolve_hostname: true,
}
}
@ -103,23 +131,19 @@ impl PostRequestCallbacks {
sender
}
/// Enable or disable reverse DNS lookup of the client's hostname.
pub fn resolve_hostname(&mut self, enable: bool) {
self.resolve_hostname = enable;
}
async fn run<H>(
async fn run<C, H>(
self,
config: C,
start_time: Option<Instant>,
content_length: Option<u64>,
hostname_future: Option<H>,
) where
C: PostRequestConfig,
H: Future<Output = Option<String>> + Send + 'static,
{
let Self {
callbacks,
delay_signal,
resolve_hostname,
} = self;
// If a delay has been set, wait until the entire response has been
@ -137,7 +161,7 @@ impl PostRequestCallbacks {
// Resolve client hostname if enabled.
let client_hostname = match hostname_future {
Some(hostname) if resolve_hostname => hostname.await,
Some(hostname) if config.resolve_hostname() => hostname.await,
_ => None,
};