mononoke/lfs_server: don't use closures for middleware, only use State

Summary: We no longer have any middleware that requires to be able to capture variables from the "inbound" phase (i.e. prior to handling the request) to the "outbound" phase (i.e. once the response is ready). Instead, we're passing everything through the State. So, let's get rid of the dynamism we don't need.

Reviewed By: HarveyHunt

Differential Revision: D17503373

fbshipit-source-id: 569d180250821aa3707245133a223b1f4efba3b6
This commit is contained in:
Thomas Orozco 2019-09-23 05:05:42 -07:00 committed by Facebook Github Bot
parent 65e3b6ac61
commit bfc2f9a144
8 changed files with 61 additions and 70 deletions

View File

@ -39,13 +39,10 @@ impl Handler for MononokeLfsHandler {
fn handle(self, mut state: State) -> Box<HandlerFuture> {
// On request, middleware is called in order, then called the other way around on response.
// This is what regular Router middleware in Gotham would do.
let mut callbacks: Vec<_> = self
.middleware
.iter()
.map(|m| m.handle(&mut state))
.collect();
callbacks.reverse();
let middleware = self.middleware.clone();
for m in middleware.iter() {
m.inbound(&mut state);
}
// NOTE: It's a bit unfortunate that we have to return a HandlerFuture here when really
// we'd rather be working with just (State, HttpResponse<Body>) everywhere, but that's how
@ -56,8 +53,8 @@ impl Handler for MononokeLfsHandler {
(state, response)
});
for callback in callbacks.into_iter() {
callback(&mut state, &mut response);
for m in middleware.iter().rev() {
m.outbound(&mut state, &mut response);
}
Ok((state, response))

View File

@ -8,14 +8,14 @@ use failure::Error;
use gotham::helpers::http::header::X_REQUEST_ID;
use gotham::state::{request_id, State};
use hyper::header::HeaderValue;
use hyper::{Body, Response};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
use super::{Callback, Middleware};
use super::Middleware;
pub struct IdentityMiddleware {
headers: Arc<HashMap<&'static str, HeaderValue>>,
headers: HashMap<&'static str, HeaderValue>,
}
impl IdentityMiddleware {
@ -30,9 +30,7 @@ impl IdentityMiddleware {
let _ = Self::add_tw_task_version(&mut headers);
let _ = Self::add_tw_canary_id(&mut headers);
Self {
headers: Arc::new(headers),
}
Self { headers }
}
fn add_tw_task(headers: &mut HashMap<&'static str, HeaderValue>) -> Result<(), Error> {
@ -65,19 +63,15 @@ impl IdentityMiddleware {
}
impl Middleware for IdentityMiddleware {
fn handle(&self, _state: &mut State) -> Callback {
let headers_to_add = self.headers.clone();
fn outbound(&self, state: &mut State, response: &mut Response<Body>) {
let headers = response.headers_mut();
Box::new(move |state, response| {
let headers = response.headers_mut();
for (header, value) in self.headers.iter() {
headers.insert(*header, value.clone());
}
for (header, value) in headers_to_add.iter() {
headers.insert(*header, value.clone());
}
if let Ok(id) = HeaderValue::from_str(request_id(&state)) {
headers.insert(X_REQUEST_ID, id);
}
})
if let Ok(id) = HeaderValue::from_str(request_id(&state)) {
headers.insert(X_REQUEST_ID, id);
}
}
}

View File

@ -5,23 +5,21 @@
// GNU General Public License version 2 or any later version.
use gotham::state::{client_addr, request_id, FromState, State};
use hyper::{Body, Response};
use hyper::{Method, StatusCode, Uri, Version};
use slog::{info, Logger};
use std::sync::Arc;
use time_ext::DurationExt;
use super::{Callback, Middleware, RequestContext};
use super::{Middleware, RequestContext};
#[derive(Clone)]
pub struct LogMiddleware {
logger: Arc<Logger>,
logger: Logger,
}
impl LogMiddleware {
pub fn new(logger: Logger) -> Self {
Self {
logger: Arc::new(logger),
}
Self { logger }
}
}
@ -59,10 +57,7 @@ fn log_request(logger: &Logger, state: &State, status: &StatusCode) -> Option<()
}
impl Middleware for LogMiddleware {
fn handle(&self, _state: &mut State) -> Callback {
let logger = self.logger.clone();
Box::new(move |state, response| {
log_request(&logger, &state, &response.status());
})
fn outbound(&self, state: &mut State, response: &mut Response<Body>) {
log_request(&self.logger, &state, &response.status());
}
}

View File

@ -22,8 +22,14 @@ pub use self::request_context::{RequestContext, RequestContextMiddleware};
pub use self::scuba::{ScubaMiddleware, ScubaMiddlewareState};
pub use self::timer::TimerMiddleware;
pub type Callback = Box<dyn FnOnce(&mut State, &mut Response<Body>) + 'static + Send + Sync>;
pub trait Middleware: 'static + RefUnwindSafe + Send + Sync {
fn handle(&self, state: &mut State) -> Callback;
fn inbound(&self, _state: &mut State) {
// Implement inbound to perform pre-request actions, such as putting something in the
// state.
}
fn outbound(&self, _state: &mut State, _response: &mut Response<Body>) {
// Implement outbound to perform post-request actions, such as logging the response status
// code.
}
}

View File

@ -6,10 +6,11 @@
use gotham::state::State;
use hyper::StatusCode;
use hyper::{Body, Response};
use stats::{define_stats, DynamicHistogram, DynamicTimeseries};
use time_ext::DurationExt;
use super::{Callback, Middleware, RequestContext};
use super::{Middleware, RequestContext};
define_stats! {
prefix = "mononoke.lfs.request";
@ -53,9 +54,7 @@ impl OdsMiddleware {
}
impl Middleware for OdsMiddleware {
fn handle(&self, _state: &mut State) -> Callback {
Box::new(|state, response| {
log_stats(state, response.status());
})
fn outbound(&self, state: &mut State, response: &mut Response<Body>) {
log_stats(state, response.status());
}
}

View File

@ -7,13 +7,14 @@
use futures::{Future, IntoFuture};
use gotham::state::State;
use gotham_derive::StateData;
use hyper::{Body, Response};
use std::time::{Duration, Instant};
use tokio::{
self,
sync::oneshot::{channel, Receiver, Sender},
};
use super::{Callback, Middleware};
use super::Middleware;
type PostRequestCallback = Box<dyn FnOnce(&Duration) + Sync + Send + 'static>;
@ -116,13 +117,13 @@ impl RequestContextMiddleware {
}
impl Middleware for RequestContextMiddleware {
fn handle(&self, state: &mut State) -> Callback {
fn inbound(&self, state: &mut State) {
state.put(RequestContext::new());
}
Box::new(|state, _response| {
if let Some(ctx) = state.try_take::<RequestContext>() {
ctx.dispatch_post_request();
}
})
fn outbound(&self, state: &mut State, _response: &mut Response<Body>) {
if let Some(ctx) = state.try_take::<RequestContext>() {
ctx.dispatch_post_request();
}
}
}

View File

@ -10,12 +10,13 @@ use hyper::{
header::{self, AsHeaderName, HeaderMap},
Method, StatusCode, Uri,
};
use hyper::{Body, Response};
use json_encoded::get_identities;
use percent_encoding::percent_decode;
use scuba::{ScubaSampleBuilder, ScubaValue};
use time_ext::DurationExt;
use super::{Callback, Middleware, RequestContext};
use super::{Middleware, RequestContext};
const ENCODED_CLIENT_IDENTITY: &str = "x-fb-validated-client-encoded-identity";
const CLIENT_IP: &str = "tfb-orig-client-ip";
@ -125,17 +126,17 @@ impl ScubaMiddlewareState {
}
impl Middleware for ScubaMiddleware {
fn handle(&self, state: &mut State) -> Callback {
fn inbound(&self, state: &mut State) {
state.put(ScubaMiddlewareState(self.scuba.clone()));
}
Box::new(|state, response| {
if let Some(uri) = Uri::try_borrow_from(&state) {
if uri.path() == "/health_check" {
return;
}
fn outbound(&self, state: &mut State, response: &mut Response<Body>) {
if let Some(uri) = Uri::try_borrow_from(&state) {
if uri.path() == "/health_check" {
return;
}
}
log_stats(state, &response.status());
})
log_stats(state, &response.status());
}
}

View File

@ -5,8 +5,9 @@
// GNU General Public License version 2 or any later version.
use gotham::state::State;
use hyper::{Body, Response};
use super::{Callback, Middleware, RequestContext};
use super::{Middleware, RequestContext};
#[derive(Clone)]
pub struct TimerMiddleware {}
@ -18,12 +19,9 @@ impl TimerMiddleware {
}
impl Middleware for TimerMiddleware {
fn handle(&self, _state: &mut State) -> Callback {
// TODO: Rework the Callback stuff...
Box::new(move |state, _response| {
if let Some(ctx) = state.try_borrow_mut::<RequestContext>() {
ctx.headers_ready();
}
})
fn outbound(&self, state: &mut State, _response: &mut Response<Body>) {
if let Some(ctx) = state.try_borrow_mut::<RequestContext>() {
ctx.headers_ready();
}
}
}