diff --git a/eden/mononoke/edenapi_server/src/main.rs b/eden/mononoke/edenapi_server/src/main.rs index 7dfb45dd41..c046d0b165 100644 --- a/eden/mononoke/edenapi_server/src/main.rs +++ b/eden/mononoke/edenapi_server/src/main.rs @@ -125,7 +125,7 @@ async fn start( let readonly_storage = args::parse_readonly_storage(&matches); let blobstore_options = args::parse_blobstore_options(&matches)?; let disabled_hooks = args::parse_disabled_hooks_with_repo_prefix(&matches, &logger)?; - let trusted_proxy_idents = parse_identities(&matches)?; + let trusted_proxy_idents = Arc::new(parse_identities(&matches)?); let tls_session_data_log = matches.value_of(ARG_TLS_SESSION_DATA_LOG_FILE); let mut scuba_logger = args::get_scuba_sample_builder(fb, &matches, &logger)?; @@ -167,7 +167,7 @@ async fn start( let router = build_router(ctx); let handler = MononokeHttpHandler::builder() .add(TlsSessionDataMiddleware::new(tls_session_data_log)?) - .add(ClientIdentityMiddleware::new(trusted_proxy_idents)) + .add(ClientIdentityMiddleware::new()) .add(ServerIdentityMiddleware::new(HeaderValue::from_static( "edenapi_server", ))) @@ -197,7 +197,7 @@ async fn start( bind_server_with_socket_data(listener, handler, { cloned!(logger); move |socket| { - cloned!(acceptor, logger); + cloned!(acceptor, logger, trusted_proxy_idents); async move { let ssl_socket = match tokio_openssl::accept(&acceptor, socket).await { Ok(ssl_socket) => ssl_socket, @@ -207,8 +207,11 @@ async fn start( } }; - let socket_data = - TlsSocketData::from_ssl(ssl_socket.ssl(), capture_session_data); + let socket_data = TlsSocketData::from_ssl( + ssl_socket.ssl(), + trusted_proxy_idents.as_ref(), + capture_session_data, + ); Ok((socket_data, ssl_socket)) } diff --git a/eden/mononoke/gotham_ext/src/middleware/client_identity.rs b/eden/mononoke/gotham_ext/src/middleware/client_identity.rs index d1841a9f90..2af480e3b3 100644 --- a/eden/mononoke/gotham_ext/src/middleware/client_identity.rs +++ b/eden/mononoke/gotham_ext/src/middleware/client_identity.rs @@ -14,7 +14,6 @@ use lazy_static::lazy_static; use percent_encoding::percent_decode; use permission_checker::{MononokeIdentity, MononokeIdentitySet}; use std::net::{IpAddr, SocketAddr}; -use std::sync::Arc; use trust_dns_resolver::TokioAsyncResolver; use super::Middleware; @@ -87,27 +86,21 @@ impl ClientIdentity { } #[derive(Clone)] -pub struct ClientIdentityMiddleware { - trusted_proxy_allowlist: Arc, -} +pub struct ClientIdentityMiddleware; impl ClientIdentityMiddleware { - pub fn new(trusted_proxy_idents: MononokeIdentitySet) -> Self { - Self { - trusted_proxy_allowlist: Arc::new(trusted_proxy_idents), - } + pub fn new() -> Self { + Self } fn extract_client_identities( &self, - cert_idents: MononokeIdentitySet, + tls_certificate_identities: TlsCertificateIdentities, headers: &HeaderMap, ) -> Option { - let is_trusted_proxy = !self.trusted_proxy_allowlist.is_disjoint(&cert_idents); - if is_trusted_proxy { - request_identities_from_headers(&headers) - } else { - Some(cert_idents) + match tls_certificate_identities { + TlsCertificateIdentities::TrustedProxy => request_identities_from_headers(&headers), + TlsCertificateIdentities::Authenticated(idents) => Some(idents), } } } @@ -144,8 +137,7 @@ impl Middleware for ClientIdentityMiddleware { client_identity.client_correlator = request_client_correlator_from_headers(&headers); if let Some(cert_idents) = cert_idents { - client_identity.identities = - self.extract_client_identities(cert_idents.identities, &headers); + client_identity.identities = self.extract_client_identities(cert_idents, &headers); } } diff --git a/eden/mononoke/gotham_ext/src/socket_data.rs b/eden/mononoke/gotham_ext/src/socket_data.rs index c9d76e966d..83f9d4db5b 100644 --- a/eden/mononoke/gotham_ext/src/socket_data.rs +++ b/eden/mononoke/gotham_ext/src/socket_data.rs @@ -17,8 +17,12 @@ pub struct TlsSocketData { } impl TlsSocketData { - pub fn from_ssl(ssl: &SslRef, capture_session_data: bool) -> Self { - let identities = TlsCertificateIdentities::from_ssl(ssl); + pub fn from_ssl( + ssl: &SslRef, + trusted_proxy_allowlist: &MononokeIdentitySet, + capture_session_data: bool, + ) -> Self { + let identities = TlsCertificateIdentities::from_ssl(ssl, trusted_proxy_allowlist); let session_data = if capture_session_data { TlsSessionData::from_ssl(ssl) @@ -78,15 +82,22 @@ impl TlsSessionData { } #[derive(Clone, StateData)] -pub struct TlsCertificateIdentities { - pub identities: MononokeIdentitySet, +pub enum TlsCertificateIdentities { + TrustedProxy, + Authenticated(MononokeIdentitySet), } impl TlsCertificateIdentities { - pub fn from_ssl(ssl: &SslRef) -> Option { + pub fn from_ssl(ssl: &SslRef, trusted_proxy_allowlist: &MononokeIdentitySet) -> Option { let peer_certificate = ssl.peer_certificate()?; - Some(Self { - identities: MononokeIdentity::try_from_x509(&peer_certificate).ok()?, - }) + let identities = MononokeIdentity::try_from_x509(&peer_certificate).ok()?; + + let ret = if trusted_proxy_allowlist.is_disjoint(&identities) { + TlsCertificateIdentities::Authenticated(identities) + } else { + TlsCertificateIdentities::TrustedProxy + }; + + Some(ret) } } diff --git a/eden/mononoke/lfs_server/src/main.rs b/eden/mononoke/lfs_server/src/main.rs index 0c6ed525eb..ccd3ff9048 100644 --- a/eden/mononoke/lfs_server/src/main.rs +++ b/eden/mononoke/lfs_server/src/main.rs @@ -247,7 +247,9 @@ fn main(fb: FacebookInit) -> Result<(), Error> { let mut scuba_logger = args::get_scuba_sample_builder(fb, &matches, &logger)?; - let trusted_proxy_idents = idents_from_values(matches.values_of(ARG_TRUSTED_PROXY_IDENTITY))?; + let trusted_proxy_idents = Arc::new(idents_from_values( + matches.values_of(ARG_TRUSTED_PROXY_IDENTITY), + )?); scuba_logger.add_common_server_data(); @@ -353,7 +355,7 @@ fn main(fb: FacebookInit) -> Result<(), Error> { let handler = MononokeHttpHandler::builder() .add(TlsSessionDataMiddleware::new(tls_session_data_log)?) - .add(ClientIdentityMiddleware::new(trusted_proxy_idents)) + .add(ClientIdentityMiddleware::new()) .add(PostRequestMiddleware::with_config(config_handle)) .add(RequestContextMiddleware::new(fb, logger.clone())) .add(LoadMiddleware::new()) @@ -403,7 +405,7 @@ fn main(fb: FacebookInit) -> Result<(), Error> { listener .incoming() .for_each(move |socket| { - cloned!(acceptor, logger, protocol, handler); + cloned!(acceptor, logger, protocol, handler, trusted_proxy_idents); let task = async move { let socket = socket.context("Error obtaining socket")?; @@ -412,8 +414,11 @@ fn main(fb: FacebookInit) -> Result<(), Error> { .await .context("Error performing TLS handshake")?; - let socket_data = - TlsSocketData::from_ssl(ssl_socket.ssl(), capture_session_data); + let socket_data = TlsSocketData::from_ssl( + ssl_socket.ssl(), + trusted_proxy_idents.as_ref(), + capture_session_data, + ); let service = ConnectedGothamService::connect(handler, addr, socket_data);