Fix bug when loading bucket from database

This commit is contained in:
Antonio Scandurra 2024-05-29 11:55:40 +02:00
parent 34b88d14f6
commit c19083a35c

View File

@ -62,7 +62,7 @@ impl RateLimiter {
let mut bucket = self
.buckets
.entry(bucket_key.clone())
.or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now));
.or_insert_with(|| RateBucket::new::<T>(now));
if bucket.value_mut().allow(now) {
self.dirty_buckets.insert(bucket_key);
@ -72,19 +72,19 @@ impl RateLimiter {
}
}
async fn load_bucket<K: RateLimit>(
async fn load_bucket<T: RateLimit>(
&self,
user_id: UserId,
) -> Result<Option<RateBucket>, Error> {
Ok(self
.db
.get_rate_bucket(user_id, K::db_name())
.get_rate_bucket(user_id, T::db_name())
.await?
.map(|saved_bucket| RateBucket {
capacity: K::capacity(),
refill_time_per_token: K::refill_duration(),
token_count: saved_bucket.token_count as usize,
last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
.map(|saved_bucket| {
RateBucket::from_db::<T>(
saved_bucket.token_count as usize,
DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
)
}))
}
@ -124,15 +124,24 @@ struct RateBucket {
}
impl RateBucket {
fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self {
RateBucket {
capacity,
token_count: capacity,
refill_time_per_token: refill_duration / capacity as i32,
fn new<T: RateLimit>(now: DateTimeUtc) -> Self {
Self {
capacity: T::capacity(),
token_count: T::capacity(),
refill_time_per_token: T::refill_duration() / T::capacity() as i32,
last_refill: now,
}
}
fn from_db<T: RateLimit>(token_count: usize, last_refill: DateTimeUtc) -> Self {
Self {
capacity: T::capacity(),
token_count,
refill_time_per_token: T::refill_duration() / T::capacity() as i32,
last_refill,
}
}
fn allow(&mut self, now: DateTimeUtc) -> bool {
self.refill(now);
if self.token_count > 0 {
@ -252,6 +261,17 @@ mod tests {
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap_err();
// After 1s, user 1 can make another request before being rate-limited again.
now += Duration::seconds(1);
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap();
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap_err();
}
struct RateLimitA;