http-client: add a callback for all requests

Summary:
It turns out there are multiple users sending requests bypassing the
HttpClient, like the LFS in revisionstore, or the segmented changelog
clone.

Requests bypassing HttpClient means HttpClient event callbacks do not
have a chance to insert progress and bandwidth monitoring. So let's
add another callback that can capture what HttpClient misses. This would allow
us to get proper progress bars of revisionstore LFS and segmented clone without
changing their code.

Reviewed By: andll

Differential Revision: D26970748

fbshipit-source-id: 5133bc6f9eeb14a6d2944d253bc66cefd49c83c5
This commit is contained in:
Jun Wu 2021-03-11 17:16:35 -08:00 committed by Facebook GitHub Bot
parent 44df77ef6b
commit 2f032a420a
2 changed files with 50 additions and 1 deletions

View File

@ -101,6 +101,14 @@ gen_event_listeners! {
}
}
gen_event_listeners! {
/// Events for request creation (both independent requests and requests via `HttpClient`)
RequestCreationEventListeners {
/// A request is created.
new_request(req: &mut RequestContext),
}
}
#[cfg(test)]
mod tests {
use std::sync::atomic::AtomicUsize;

View File

@ -16,11 +16,14 @@ use curl::{
self,
easy::{Easy2, HttpVersion, List},
};
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use serde::Serialize;
use url::Url;
use crate::{
errors::HttpClientError,
event_listeners::RequestCreationEventListeners,
event_listeners::RequestEventListeners,
handler::{Buffered, HandlerExt, Streaming},
receiver::{ChannelReceiver, Receiver},
@ -87,6 +90,9 @@ pub struct Request {
min_transfer_speed: Option<MinTransferSpeed>,
}
static REQUEST_CREATION_LISTENERS: Lazy<RwLock<RequestCreationEventListeners>> =
Lazy::new(Default::default);
impl RequestContext {
/// Create a [`RequestContext`].
pub fn new(url: Url, method: Method) -> Self {
@ -302,9 +308,12 @@ impl Request {
/// Turn this `Request` into a `curl::Easy2` handle using
/// the given `Handler` to process the response.
pub(crate) fn into_handle<H: HandlerExt>(
self,
mut self,
create_handler: impl FnOnce(RequestContext) -> H,
) -> Result<Easy2<H>, HttpClientError> {
REQUEST_CREATION_LISTENERS
.read()
.trigger_new_request(&mut self.ctx);
let body_size = self.ctx.body.as_ref().map(|body| body.len() as u64);
let url = self.ctx.url.clone();
let handler = create_handler(self.ctx);
@ -378,6 +387,11 @@ impl Request {
Ok(easy)
}
/// Register a callback function that is called on new requests.
pub fn on_new_request(f: impl Fn(&mut RequestContext) + Send + Sync + 'static) {
REQUEST_CREATION_LISTENERS.write().on_new_request(f);
}
}
impl TryFrom<Request> for Easy2<Buffered> {
@ -419,6 +433,9 @@ impl<R: Receiver> TryFrom<StreamRequest<R>> for Easy2<Streaming<R>> {
mod tests {
use super::*;
use std::sync::atomic::Ordering::Acquire;
use std::sync::Arc;
use anyhow::Result;
use futures::TryStreamExt;
use http::{
@ -656,4 +673,28 @@ mod tests {
let req2 = RequestContext::dummy();
assert_ne!(req.id(), req2.id());
}
#[test]
fn test_request_callback() -> Result<()> {
let called = Arc::new(AtomicUsize::new(0));
Request::on_new_request({
let called = called.clone();
move |req| {
// The callback can receive requests in other tests.
// So we need to check the request is sent by this test.
if req.url().path() == "/test_callback" {
called.fetch_add(1, AcqRel);
}
}
});
let mock = mock("HEAD", "/test_callback").with_status(200).create();
let url = Url::parse(&mockito::server_url())?.join("test_callback")?;
let _res = Request::head(url).send()?;
mock.assert();
assert_eq!(called.load(Acquire), 1);
Ok(())
}
}