mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
collab: Adapt rate limits based on plan (#15548)
This PR updates the rate limits so that they adapt to the user's current plan. For the Free plan, the limits are set to one-tenth of the existing rate limits (which are now the Pro limits); we can adjust them as needed. Release Notes: - N/A. Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
7a0149f17c
commit
8c54a46202
@ -169,9 +169,7 @@ struct ManageBillingSubscriptionBody {
|
||||
github_user_id: i32,
|
||||
intent: ManageSubscriptionIntent,
|
||||
/// The ID of the subscription to manage.
|
||||
///
|
||||
/// If not provided, we will try to use the active subscription (if there is only one).
|
||||
subscription_id: Option<BillingSubscriptionId>,
|
||||
subscription_id: BillingSubscriptionId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@ -206,23 +204,11 @@ async fn manage_billing_subscription(
|
||||
let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
|
||||
.context("failed to parse customer ID")?;
|
||||
|
||||
let subscription = if let Some(subscription_id) = body.subscription_id {
|
||||
app.db
|
||||
.get_billing_subscription_by_id(subscription_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("subscription not found"))?
|
||||
} else {
|
||||
// If no subscription ID was provided, try to find the only active subscription ID.
|
||||
let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
|
||||
if subscriptions.len() > 1 {
|
||||
Err(anyhow!("user has multiple active subscriptions"))?;
|
||||
}
|
||||
|
||||
subscriptions
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("user has no active subscriptions"))?
|
||||
};
|
||||
let subscription = app
|
||||
.db
|
||||
.get_billing_subscription_by_id(body.subscription_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("subscription not found"))?;
|
||||
|
||||
let flow = match body.intent {
|
||||
ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
|
||||
|
@ -110,13 +110,15 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns all of the active billing subscriptions for the user with the specified ID.
|
||||
pub async fn get_active_billing_subscriptions(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Vec<billing_subscription::Model>> {
|
||||
/// Returns whether the user has an active billing subscription.
|
||||
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
|
||||
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
|
||||
}
|
||||
|
||||
/// Returns the count of the active billing subscriptions for the user with the specified ID.
|
||||
pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
|
||||
self.transaction(|tx| async move {
|
||||
let subscriptions = billing_subscription::Entity::find()
|
||||
let count = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.filter(
|
||||
billing_customer::Column::UserId.eq(user_id).and(
|
||||
@ -124,11 +126,10 @@ impl Database {
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
),
|
||||
)
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.all(&*tx)
|
||||
.count(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(subscriptions)
|
||||
Ok(count as usize)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -17,9 +17,12 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
|
||||
// A user with no subscription has no active billing subscriptions.
|
||||
{
|
||||
let user_id = new_test_user(db, "no-subscription-user@example.com").await;
|
||||
let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
|
||||
let subscription_count = db
|
||||
.count_active_billing_subscriptions(user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(subscriptions.len(), 0);
|
||||
assert_eq!(subscription_count, 0);
|
||||
}
|
||||
|
||||
// A user with an active subscription has one active billing subscription.
|
||||
@ -42,7 +45,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
|
||||
let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap();
|
||||
assert_eq!(subscriptions.len(), 1);
|
||||
|
||||
let subscription = &subscriptions[0];
|
||||
@ -76,7 +79,10 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let subscriptions = db.get_active_billing_subscriptions(user_id).await.unwrap();
|
||||
assert_eq!(subscriptions.len(), 0);
|
||||
let subscription_count = db
|
||||
.count_active_billing_subscriptions(user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(subscription_count, 0);
|
||||
}
|
||||
}
|
||||
|
@ -6,10 +6,10 @@ use sea_orm::prelude::DateTimeUtc;
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt;
|
||||
|
||||
pub trait RateLimit: 'static {
|
||||
fn capacity() -> usize;
|
||||
fn refill_duration() -> Duration;
|
||||
fn db_name() -> &'static str;
|
||||
pub trait RateLimit: Send + Sync {
|
||||
fn capacity(&self) -> usize;
|
||||
fn refill_duration(&self) -> Duration;
|
||||
fn db_name(&self) -> &'static str;
|
||||
}
|
||||
|
||||
/// Used to enforce per-user rate limits
|
||||
@ -42,18 +42,23 @@ impl RateLimiter {
|
||||
|
||||
/// Returns an error if the user has exceeded the specified `RateLimit`.
|
||||
/// Attempts to read the bucket from the database if no cached RateBucket currently exists.
|
||||
pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
|
||||
self.check_internal::<T>(user_id, Utc::now()).await
|
||||
pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> {
|
||||
self.check_internal(limit, user_id, Utc::now()).await
|
||||
}
|
||||
|
||||
async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
|
||||
let bucket_key = (user_id, T::db_name().to_string());
|
||||
async fn check_internal(
|
||||
&self,
|
||||
limit: &dyn RateLimit,
|
||||
user_id: UserId,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<()> {
|
||||
let bucket_key = (user_id, limit.db_name().to_string());
|
||||
|
||||
// Attempt to fetch the bucket from the database if it hasn't been cached.
|
||||
// For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
|
||||
// but this enforces limits across restarts so long as the database is reachable.
|
||||
if !self.buckets.contains_key(&bucket_key) {
|
||||
if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
|
||||
if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
|
||||
self.buckets.insert(bucket_key.clone(), bucket);
|
||||
self.dirty_buckets.insert(bucket_key.clone());
|
||||
}
|
||||
@ -62,7 +67,7 @@ impl RateLimiter {
|
||||
let mut bucket = self
|
||||
.buckets
|
||||
.entry(bucket_key.clone())
|
||||
.or_insert_with(|| RateBucket::new::<T>(now));
|
||||
.or_insert_with(|| RateBucket::new(limit, now));
|
||||
|
||||
if bucket.value_mut().allow(now) {
|
||||
self.dirty_buckets.insert(bucket_key);
|
||||
@ -72,16 +77,18 @@ impl RateLimiter {
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_bucket<T: RateLimit>(
|
||||
async fn load_bucket(
|
||||
&self,
|
||||
limit: &dyn RateLimit,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<RateBucket>, Error> {
|
||||
Ok(self
|
||||
.db
|
||||
.get_rate_bucket(user_id, T::db_name())
|
||||
.get_rate_bucket(user_id, limit.db_name())
|
||||
.await?
|
||||
.map(|saved_bucket| {
|
||||
RateBucket::from_db::<T>(
|
||||
RateBucket::from_db(
|
||||
limit,
|
||||
saved_bucket.token_count as usize,
|
||||
DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
|
||||
)
|
||||
@ -124,20 +131,20 @@ struct RateBucket {
|
||||
}
|
||||
|
||||
impl RateBucket {
|
||||
fn new<T: RateLimit>(now: DateTimeUtc) -> Self {
|
||||
fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self {
|
||||
Self {
|
||||
capacity: T::capacity(),
|
||||
token_count: T::capacity(),
|
||||
refill_time_per_token: T::refill_duration() / T::capacity() as i32,
|
||||
capacity: limit.capacity(),
|
||||
token_count: limit.capacity(),
|
||||
refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
|
||||
last_refill: now,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_db<T: RateLimit>(token_count: usize, last_refill: DateTimeUtc) -> Self {
|
||||
fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self {
|
||||
Self {
|
||||
capacity: T::capacity(),
|
||||
capacity: limit.capacity(),
|
||||
token_count,
|
||||
refill_time_per_token: T::refill_duration() / T::capacity() as i32,
|
||||
refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
|
||||
last_refill,
|
||||
}
|
||||
}
|
||||
@ -205,50 +212,52 @@ mod tests {
|
||||
let mut now = Utc::now();
|
||||
|
||||
let rate_limiter = RateLimiter::new(db.clone());
|
||||
let rate_limit_a = Box::new(RateLimitA);
|
||||
let rate_limit_b = Box::new(RateLimitB);
|
||||
|
||||
// User 1 can access resource A two times before being rate-limited.
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitA>(user_1, now)
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitA>(user_1, now)
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitA>(user_1, now)
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
// User 2 can access resource B, and user 1 (rate-limited on A) can still access resource B.
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitB>(user_2, now)
|
||||
.check_internal(&*rate_limit_b, user_2, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitB>(user_1, now)
|
||||
.check_internal(&*rate_limit_b, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// After 1.5s, user 1 can make another request before being rate-limited again.
|
||||
now += Duration::milliseconds(1500);
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitA>(user_1, now)
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitA>(user_1, now)
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
// After 500ms, user 1 can make another request before being rate-limited again.
|
||||
now += Duration::milliseconds(500);
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitA>(user_1, now)
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitA>(user_1, now)
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
@ -258,18 +267,18 @@ mod tests {
|
||||
// for resource A.
|
||||
let rate_limiter = RateLimiter::new(db.clone());
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitA>(user_1, now)
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
// After 1s, user 1 can make another request before being rate-limited again.
|
||||
now += Duration::seconds(1);
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitA>(user_1, now)
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal::<RateLimitA>(user_1, now)
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
@ -277,15 +286,15 @@ mod tests {
|
||||
struct RateLimitA;
|
||||
|
||||
impl RateLimit for RateLimitA {
|
||||
fn capacity() -> usize {
|
||||
fn capacity(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn refill_duration() -> Duration {
|
||||
fn refill_duration(&self) -> Duration {
|
||||
Duration::seconds(2)
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
fn db_name(&self) -> &'static str {
|
||||
"rate-limit-a"
|
||||
}
|
||||
}
|
||||
@ -293,15 +302,15 @@ mod tests {
|
||||
struct RateLimitB;
|
||||
|
||||
impl RateLimit for RateLimitB {
|
||||
fn capacity() -> usize {
|
||||
fn capacity(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
fn refill_duration() -> Duration {
|
||||
fn refill_duration(&self) -> Duration {
|
||||
Duration::seconds(3)
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
fn db_name(&self) -> &'static str {
|
||||
"rate-limit-b"
|
||||
}
|
||||
}
|
||||
|
@ -199,6 +199,23 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
|
||||
if self.is_staff() {
|
||||
return Ok(proto::Plan::ZedPro);
|
||||
}
|
||||
|
||||
let Some(user_id) = self.user_id() else {
|
||||
return Ok(proto::Plan::Free);
|
||||
};
|
||||
|
||||
let db = self.db().await;
|
||||
if db.has_active_billing_subscription(user_id).await? {
|
||||
Ok(proto::Plan::ZedPro)
|
||||
} else {
|
||||
Ok(proto::Plan::Free)
|
||||
}
|
||||
}
|
||||
|
||||
fn dev_server_id(&self) -> Option<DevServerId> {
|
||||
match &self.principal {
|
||||
Principal::User(_) | Principal::Impersonated { .. } => None,
|
||||
@ -3537,15 +3554,8 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
|
||||
version.0.minor() < 139
|
||||
}
|
||||
|
||||
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let active_subscriptions = db.get_active_billing_subscriptions(user_id).await?;
|
||||
|
||||
let plan = if session.is_staff() || !active_subscriptions.is_empty() {
|
||||
proto::Plan::ZedPro
|
||||
} else {
|
||||
proto::Plan::Free
|
||||
};
|
||||
async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
|
||||
let plan = session.current_plan().await?;
|
||||
|
||||
session
|
||||
.peer
|
||||
@ -4532,22 +4542,41 @@ async fn acknowledge_buffer_version(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct CompleteWithLanguageModelRateLimit;
|
||||
struct ZedProCompleteWithLanguageModelRateLimit;
|
||||
|
||||
impl RateLimit for CompleteWithLanguageModelRateLimit {
|
||||
fn capacity() -> usize {
|
||||
impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(120) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration() -> chrono::Duration {
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
"complete-with-language-model"
|
||||
fn db_name(&self) -> &'static str {
|
||||
"zed-pro:complete-with-language-model"
|
||||
}
|
||||
}
|
||||
|
||||
struct FreeCompleteWithLanguageModelRateLimit;
|
||||
|
||||
impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(120 / 10) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"free:complete-with-language-model"
|
||||
}
|
||||
}
|
||||
|
||||
@ -4562,9 +4591,14 @@ async fn complete_with_language_model(
|
||||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
@ -4602,9 +4636,14 @@ async fn stream_complete_with_language_model(
|
||||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
@ -4684,9 +4723,14 @@ async fn count_language_model_tokens(
|
||||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CountLanguageModelTokensRateLimit>(session.user_id())
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
@ -4713,41 +4757,79 @@ async fn count_language_model_tokens(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct CountLanguageModelTokensRateLimit;
|
||||
struct ZedProCountLanguageModelTokensRateLimit;
|
||||
|
||||
impl RateLimit for CountLanguageModelTokensRateLimit {
|
||||
fn capacity() -> usize {
|
||||
impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(600) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration() -> chrono::Duration {
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
"count-language-model-tokens"
|
||||
fn db_name(&self) -> &'static str {
|
||||
"zed-pro:count-language-model-tokens"
|
||||
}
|
||||
}
|
||||
|
||||
struct ComputeEmbeddingsRateLimit;
|
||||
struct FreeCountLanguageModelTokensRateLimit;
|
||||
|
||||
impl RateLimit for ComputeEmbeddingsRateLimit {
|
||||
fn capacity() -> usize {
|
||||
impl RateLimit for FreeCountLanguageModelTokensRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(600 / 10) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"free:count-language-model-tokens"
|
||||
}
|
||||
}
|
||||
|
||||
struct ZedProComputeEmbeddingsRateLimit;
|
||||
|
||||
impl RateLimit for ZedProComputeEmbeddingsRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(5000) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration() -> chrono::Duration {
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
"compute-embeddings"
|
||||
fn db_name(&self) -> &'static str {
|
||||
"zed-pro:compute-embeddings"
|
||||
}
|
||||
}
|
||||
|
||||
struct FreeComputeEmbeddingsRateLimit;
|
||||
|
||||
impl RateLimit for FreeComputeEmbeddingsRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(5000 / 10) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"free:compute-embeddings"
|
||||
}
|
||||
}
|
||||
|
||||
@ -4760,9 +4842,14 @@ async fn compute_embeddings(
|
||||
let api_key = api_key.context("no OpenAI API key configured on the server")?;
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<ComputeEmbeddingsRateLimit>(session.user_id())
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
let embeddings = match request.model.as_str() {
|
||||
@ -4834,10 +4921,10 @@ async fn authorize_access_to_language_models(session: &UserSession) -> Result<()
|
||||
let db = session.db().await;
|
||||
let flags = db.get_user_flags(session.user_id()).await?;
|
||||
if flags.iter().any(|flag| flag == "language-models") {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("permission denied"))?
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(anyhow!("permission denied"))?
|
||||
}
|
||||
|
||||
/// Get a Supermaven API key for the user
|
||||
|
Loading…
Reference in New Issue
Block a user