diff --git a/eden/scm/lib/http-client/src/event_listeners.rs b/eden/scm/lib/http-client/src/event_listeners.rs index fa5d227b32..f7931fd3ac 100644 --- a/eden/scm/lib/http-client/src/event_listeners.rs +++ b/eden/scm/lib/http-client/src/event_listeners.rs @@ -101,6 +101,14 @@ gen_event_listeners! { } } +gen_event_listeners! { + /// Events for request creation (both independent requests and requests via `HttpClient`) + RequestCreationEventListeners { + /// A request is created. + new_request(req: &mut RequestContext), + } +} + #[cfg(test)] mod tests { use std::sync::atomic::AtomicUsize; diff --git a/eden/scm/lib/http-client/src/request.rs b/eden/scm/lib/http-client/src/request.rs index eb7060e87c..725b45c3cc 100644 --- a/eden/scm/lib/http-client/src/request.rs +++ b/eden/scm/lib/http-client/src/request.rs @@ -16,11 +16,14 @@ use curl::{ self, easy::{Easy2, HttpVersion, List}, }; +use once_cell::sync::Lazy; +use parking_lot::RwLock; use serde::Serialize; use url::Url; use crate::{ errors::HttpClientError, + event_listeners::RequestCreationEventListeners, event_listeners::RequestEventListeners, handler::{Buffered, HandlerExt, Streaming}, receiver::{ChannelReceiver, Receiver}, @@ -87,6 +90,9 @@ pub struct Request { min_transfer_speed: Option, } +static REQUEST_CREATION_LISTENERS: Lazy> = + Lazy::new(Default::default); + impl RequestContext { /// Create a [`RequestContext`]. pub fn new(url: Url, method: Method) -> Self { @@ -302,9 +308,12 @@ impl Request { /// Turn this `Request` into a `curl::Easy2` handle using /// the given `Handler` to process the response. pub(crate) fn into_handle( - self, + mut self, create_handler: impl FnOnce(RequestContext) -> H, ) -> Result, HttpClientError> { + REQUEST_CREATION_LISTENERS + .read() + .trigger_new_request(&mut self.ctx); let body_size = self.ctx.body.as_ref().map(|body| body.len() as u64); let url = self.ctx.url.clone(); let handler = create_handler(self.ctx); @@ -378,6 +387,11 @@ impl Request { Ok(easy) } + + /// Register a callback function that is called on new requests. + pub fn on_new_request(f: impl Fn(&mut RequestContext) + Send + Sync + 'static) { + REQUEST_CREATION_LISTENERS.write().on_new_request(f); + } } impl TryFrom for Easy2 { @@ -419,6 +433,9 @@ impl TryFrom> for Easy2> { mod tests { use super::*; + use std::sync::atomic::Ordering::Acquire; + use std::sync::Arc; + use anyhow::Result; use futures::TryStreamExt; use http::{ @@ -656,4 +673,28 @@ mod tests { let req2 = RequestContext::dummy(); assert_ne!(req.id(), req2.id()); } + + #[test] + fn test_request_callback() -> Result<()> { + let called = Arc::new(AtomicUsize::new(0)); + Request::on_new_request({ + let called = called.clone(); + move |req| { + // The callback can receive requests in other tests. + // So we need to check the request is sent by this test. + if req.url().path() == "/test_callback" { + called.fetch_add(1, AcqRel); + } + } + }); + + let mock = mock("HEAD", "/test_callback").with_status(200).create(); + let url = Url::parse(&mockito::server_url())?.join("test_callback")?; + let _res = Request::head(url).send()?; + + mock.assert(); + assert_eq!(called.load(Acquire), 1); + + Ok(()) + } }