diff --git a/Cargo.lock b/Cargo.lock index da2eeee4ad..838a8a4bbc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2464,6 +2464,7 @@ dependencies = [ "headless", "hex", "http_client", + "hyper", "indoc", "jsonwebtoken", "language", diff --git a/Cargo.toml b/Cargo.toml index cad69682b4..6dc69324ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -340,6 +340,7 @@ git2 = { version = "0.19", default-features = false } globset = "0.4" heed = { version = "0.20.1", features = ["read-txn-no-tls"] } hex = "0.4.3" +hyper = "0.14" html5ever = "0.27.0" ignore = "0.4.22" image = "0.25.1" diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index d4406348d5..019b5833ff 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,3 +1,5 @@ +mod supported_countries; + use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; @@ -6,6 +8,8 @@ use serde::{Deserialize, Serialize}; use std::time::Duration; use strum::EnumIter; +pub use supported_countries::*; + pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com"; #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] diff --git a/crates/anthropic/src/supported_countries.rs b/crates/anthropic/src/supported_countries.rs new file mode 100644 index 0000000000..a1d67791d4 --- /dev/null +++ b/crates/anthropic/src/supported_countries.rs @@ -0,0 +1,194 @@ +use std::collections::HashSet; +use std::sync::LazyLock; + +/// Returns whether the given country code is supported by Anthropic. +/// +/// https://www.anthropic.com/supported-countries +pub fn is_supported_country(country_code: &str) -> bool { + SUPPORTED_COUNTRIES.contains(&country_code) +} + +/// The list of country codes supported by Anthropic. +/// +/// https://www.anthropic.com/supported-countries +static SUPPORTED_COUNTRIES: LazyLock> = LazyLock::new(|| { + vec![ + "AL", // Albania + "DZ", // Algeria + "AD", // Andorra + "AO", // Angola + "AG", // Antigua and Barbuda + "AR", // Argentina + "AM", // Armenia + "AU", // Australia + "AT", // Austria + "AZ", // Azerbaijan + "BS", // Bahamas + "BH", // Bahrain + "BD", // Bangladesh + "BB", // Barbados + "BE", // Belgium + "BZ", // Belize + "BJ", // Benin + "BT", // Bhutan + "BO", // Bolivia + "BA", // Bosnia and Herzegovina + "BW", // Botswana + "BR", // Brazil + "BN", // Brunei + "BG", // Bulgaria + "BF", // Burkina Faso + "BI", // Burundi + "CV", // Cabo Verde + "KH", // Cambodia + "CM", // Cameroon + "CA", // Canada + "TD", // Chad + "CL", // Chile + "CO", // Colombia + "KM", // Comoros + "CG", // Congo (Brazzaville) + "CR", // Costa Rica + "CI", // Côte d'Ivoire + "HR", // Croatia + "CY", // Cyprus + "CZ", // Czechia (Czech Republic) + "DK", // Denmark + "DJ", // Djibouti + "DM", // Dominica + "DO", // Dominican Republic + "EC", // Ecuador + "EG", // Egypt + "SV", // El Salvador + "GQ", // Equatorial Guinea + "EE", // Estonia + "SZ", // Eswatini + "FJ", // Fiji + "FI", // Finland + "FR", // France + "GA", // Gabon + "GM", // Gambia + "GE", // Georgia + "DE", // Germany + "GH", // Ghana + "GR", // Greece + "GD", // Grenada + "GT", // Guatemala + "GN", // Guinea + "GW", // Guinea-Bissau + "GY", // Guyana + "HT", // Haiti + "HN", // Honduras + "HU", // Hungary + "IS", // Iceland + "IN", // India + "ID", // Indonesia + "IQ", // Iraq + "IE", // Ireland + "IL", // Israel + "IT", // Italy + "JM", // Jamaica + "JP", // Japan + "JO", // Jordan + "KZ", // Kazakhstan + "KE", // Kenya + "KI", // Kiribati + "KW", // Kuwait + "KG", // Kyrgyzstan + "LA", // Laos + "LV", // Latvia + "LB", // Lebanon + "LS", // Lesotho + "LR", // Liberia + "LI", // Liechtenstein + "LT", // Lithuania + "LU", // Luxembourg + "MG", // Madagascar + "MW", // Malawi + "MY", // Malaysia + "MV", // Maldives + "MT", // Malta + "MH", // Marshall Islands + "MR", // Mauritania + "MU", // Mauritius + "MX", // Mexico + "FM", // Micronesia + "MD", // Moldova + "MC", // Monaco + "MN", // Mongolia + "ME", // Montenegro + "MA", // Morocco + "MZ", // Mozambique + "NA", // Namibia + "NR", // Nauru + "NP", // Nepal + "NL", // Netherlands + "NZ", // New Zealand + "NE", // Niger + "NG", // Nigeria + "MK", // North Macedonia + "NO", // Norway + "OM", // Oman + "PK", // Pakistan + "PW", // Palau + "PS", // Palestine + "PA", // Panama + "PG", // Papua New Guinea + "PY", // Paraguay + "PE", // Peru + "PH", // Philippines + "PL", // Poland + "PT", // Portugal + "QA", // Qatar + "RO", // Romania + "RW", // Rwanda + "KN", // Saint Kitts and Nevis + "LC", // Saint Lucia + "VC", // Saint Vincent and the Grenadines + "WS", // Samoa + "SM", // San Marino + "ST", // São Tomé and Príncipe + "SA", // Saudi Arabia + "SN", // Senegal + "RS", // Serbia + "SC", // Seychelles + "SL", // Sierra Leone + "SG", // Singapore + "SK", // Slovakia + "SI", // Slovenia + "SB", // Solomon Islands + "ZA", // South Africa + "KR", // South Korea + "ES", // Spain + "LK", // Sri Lanka + "SR", // Suriname + "SE", // Sweden + "CH", // Switzerland + "TW", // Taiwan + "TJ", // Tajikistan + "TZ", // Tanzania + "TH", // Thailand + "TL", // Timor-Leste + "TG", // Togo + "TO", // Tonga + "TT", // Trinidad and Tobago + "TN", // Tunisia + "TR", // Türkiye (Turkey) + "TM", // Turkmenistan + "TV", // Tuvalu + "UG", // Uganda + "UA", // Ukraine (except Crimea, Donetsk, and Luhansk regions) + "AE", // United Arab Emirates + "GB", // United Kingdom + "US", // United States of America + "UY", // Uruguay + "UZ", // Uzbekistan + "VU", // Vanuatu + "VA", // Vatican City + "VN", // Vietnam + "ZM", // Zambia + "ZW", // Zimbabwe + ] + .into_iter() + .collect() +}); diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 2b6583f970..19d04ac92e 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -90,6 +90,7 @@ fs = { workspace = true, features = ["test-support"] } git = { workspace = true, features = ["test-support"] } git_hosting_providers.workspace = true gpui = { workspace = true, features = ["test-support"] } +hyper.workspace = true indoc.workspace = true language = { workspace = true, features = ["test-support"] } language_model = { workspace = true, features = ["test-support"] } diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index d88f37354e..a265e11dda 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -185,6 +185,46 @@ impl Config { _ => "https://zed.dev", } } + + #[cfg(test)] + pub fn test() -> Self { + Self { + http_port: 0, + database_url: "".into(), + database_max_connections: 0, + api_token: "".into(), + invite_link_prefix: "".into(), + live_kit_server: None, + live_kit_key: None, + live_kit_secret: None, + llm_api_secret: None, + rust_log: None, + log_json: None, + zed_environment: "test".into(), + blob_store_url: None, + blob_store_region: None, + blob_store_access_key: None, + blob_store_secret_key: None, + blob_store_bucket: None, + openai_api_key: None, + google_ai_api_key: None, + anthropic_api_key: None, + clickhouse_url: None, + clickhouse_user: None, + clickhouse_password: None, + clickhouse_database: None, + zed_client_checksum_seed: None, + slack_panics_webhook: None, + auto_join_channel_id: None, + migrations_path: None, + seed_path: None, + stripe_api_key: None, + stripe_price_id: None, + supermaven_admin_api_key: None, + qwen2_7b_api_key: None, + qwen2_7b_api_url: None, + } + } } /// The service mode that collab should run in. diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 1c5cf8625b..bde9f87e12 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,7 +1,11 @@ +mod authorization; mod token; +use crate::api::CloudflareIpCountryHeader; +use crate::llm::authorization::authorize_access_to_language_model; use crate::{executor::Executor, Config, Error, Result}; use anyhow::Context as _; +use axum::TypedHeader; use axum::{ body::Body, http::{self, HeaderName, HeaderValue, Request, StatusCode}, @@ -91,9 +95,18 @@ async fn validate_api_token(mut req: Request, next: Next) -> impl IntoR async fn perform_completion( Extension(state): Extension>, - Extension(_claims): Extension, + Extension(claims): Extension, + country_code_header: Option>, Json(params): Json, ) -> Result { + authorize_access_to_language_model( + &state.config, + &claims, + country_code_header.map(|header| header.to_string()), + params.provider, + ¶ms.model, + )?; + match params.provider { LanguageModelProvider::Anthropic => { let api_key = state diff --git a/crates/collab/src/llm/authorization.rs b/crates/collab/src/llm/authorization.rs new file mode 100644 index 0000000000..8a9945d739 --- /dev/null +++ b/crates/collab/src/llm/authorization.rs @@ -0,0 +1,213 @@ +use reqwest::StatusCode; +use rpc::LanguageModelProvider; + +use crate::llm::LlmTokenClaims; +use crate::{Config, Error, Result}; + +pub fn authorize_access_to_language_model( + config: &Config, + _claims: &LlmTokenClaims, + country_code: Option, + provider: LanguageModelProvider, + model: &str, +) -> Result<()> { + authorize_access_for_country(config, country_code, provider, model)?; + + Ok(()) +} + +fn authorize_access_for_country( + config: &Config, + country_code: Option, + provider: LanguageModelProvider, + _model: &str, +) -> Result<()> { + // In development we won't have the `CF-IPCountry` header, so we can't check + // the country code. + // + // This shouldn't be necessary, as anyone running in development will need to provide + // their own API credentials in order to use an LLM provider. + if config.is_development() { + return Ok(()); + } + + // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry + let country_code = match country_code.as_deref() { + // `XX` - Used for clients without country code data. + None | Some("XX") => Err(Error::http( + StatusCode::BAD_REQUEST, + "no country code".to_string(), + ))?, + // `T1` - Used for clients using the Tor network. + Some("T1") => Err(Error::http( + StatusCode::FORBIDDEN, + format!("access to {provider:?} models is not available over Tor"), + ))?, + Some(country_code) => country_code, + }; + + let is_country_supported_by_provider = match provider { + LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code), + LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code), + LanguageModelProvider::Google => google_ai::is_supported_country(country_code), + LanguageModelProvider::Zed => true, + }; + if !is_country_supported_by_provider { + Err(Error::http( + StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS, + format!("access to {provider:?} models is not available in your region"), + ))? + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use axum::response::IntoResponse; + use pretty_assertions::assert_eq; + use rpc::proto::Plan; + + use super::*; + + #[gpui::test] + async fn test_authorize_access_to_language_model_with_supported_country( + _cx: &mut gpui::TestAppContext, + ) { + let config = Config::test(); + + let claims = LlmTokenClaims { + user_id: 99, + plan: Plan::ZedPro, + ..Default::default() + }; + + let cases = vec![ + (LanguageModelProvider::Anthropic, "US"), // United States + (LanguageModelProvider::Anthropic, "GB"), // United Kingdom + (LanguageModelProvider::OpenAi, "US"), // United States + (LanguageModelProvider::OpenAi, "GB"), // United Kingdom + (LanguageModelProvider::Google, "US"), // United States + (LanguageModelProvider::Google, "GB"), // United Kingdom + ]; + + for (provider, country_code) in cases { + authorize_access_to_language_model( + &config, + &claims, + Some(country_code.into()), + provider, + "the-model", + ) + .unwrap_or_else(|_| { + panic!("expected authorization to return Ok for {provider:?}: {country_code}") + }) + } + } + + #[gpui::test] + async fn test_authorize_access_to_language_model_with_unsupported_country( + _cx: &mut gpui::TestAppContext, + ) { + let config = Config::test(); + + let claims = LlmTokenClaims { + user_id: 99, + plan: Plan::ZedPro, + ..Default::default() + }; + + let cases = vec![ + (LanguageModelProvider::Anthropic, "AF"), // Afghanistan + (LanguageModelProvider::Anthropic, "BY"), // Belarus + (LanguageModelProvider::Anthropic, "CF"), // Central African Republic + (LanguageModelProvider::Anthropic, "CN"), // China + (LanguageModelProvider::Anthropic, "CU"), // Cuba + (LanguageModelProvider::Anthropic, "ER"), // Eritrea + (LanguageModelProvider::Anthropic, "ET"), // Ethiopia + (LanguageModelProvider::Anthropic, "IR"), // Iran + (LanguageModelProvider::Anthropic, "KP"), // North Korea + (LanguageModelProvider::Anthropic, "XK"), // Kosovo + (LanguageModelProvider::Anthropic, "LY"), // Libya + (LanguageModelProvider::Anthropic, "MM"), // Myanmar + (LanguageModelProvider::Anthropic, "RU"), // Russia + (LanguageModelProvider::Anthropic, "SO"), // Somalia + (LanguageModelProvider::Anthropic, "SS"), // South Sudan + (LanguageModelProvider::Anthropic, "SD"), // Sudan + (LanguageModelProvider::Anthropic, "SY"), // Syria + (LanguageModelProvider::Anthropic, "VE"), // Venezuela + (LanguageModelProvider::Anthropic, "YE"), // Yemen + (LanguageModelProvider::OpenAi, "KP"), // North Korea + (LanguageModelProvider::Google, "KP"), // North Korea + ]; + + for (provider, country_code) in cases { + let error_response = authorize_access_to_language_model( + &config, + &claims, + Some(country_code.into()), + provider, + "the-model", + ) + .expect_err(&format!( + "expected authorization to return an error for {provider:?}: {country_code}" + )) + .into_response(); + + assert_eq!( + error_response.status(), + StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS + ); + let response_body = hyper::body::to_bytes(error_response.into_body()) + .await + .unwrap() + .to_vec(); + assert_eq!( + String::from_utf8(response_body).unwrap(), + format!("access to {provider:?} models is not available in your region") + ); + } + } + + #[gpui::test] + async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) { + let config = Config::test(); + + let claims = LlmTokenClaims { + user_id: 99, + plan: Plan::ZedPro, + ..Default::default() + }; + + let cases = vec![ + (LanguageModelProvider::Anthropic, "T1"), // Tor + (LanguageModelProvider::OpenAi, "T1"), // Tor + (LanguageModelProvider::Google, "T1"), // Tor + (LanguageModelProvider::Zed, "T1"), // Tor + ]; + + for (provider, country_code) in cases { + let error_response = authorize_access_to_language_model( + &config, + &claims, + Some(country_code.into()), + provider, + "the-model", + ) + .expect_err(&format!( + "expected authorization to return an error for {provider:?}: {country_code}" + )) + .into_response(); + + assert_eq!(error_response.status(), StatusCode::FORBIDDEN); + let response_body = hyper::body::to_bytes(error_response.into_body()) + .await + .unwrap() + .to_vec(); + assert_eq!( + String::from_utf8(response_body).unwrap(), + format!("access to {provider:?} models is not available over Tor") + ); + } + } +} diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index b2ecf33243..631a6b20ca 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -1,8 +1,12 @@ +mod supported_countries; + use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::HttpClient; use serde::{Deserialize, Serialize}; +pub use supported_countries::*; + pub const API_URL: &str = "https://generativelanguage.googleapis.com"; pub async fn stream_generate_content( diff --git a/crates/google_ai/src/supported_countries.rs b/crates/google_ai/src/supported_countries.rs new file mode 100644 index 0000000000..231b99d82a --- /dev/null +++ b/crates/google_ai/src/supported_countries.rs @@ -0,0 +1,232 @@ +use std::collections::HashSet; +use std::sync::LazyLock; + +/// Returns whether the given country code is supported by Google Gemini. +/// +/// https://ai.google.dev/gemini-api/docs/available-regions +pub fn is_supported_country(country_code: &str) -> bool { + SUPPORTED_COUNTRIES.contains(&country_code) +} + +/// The list of country codes supported by Google Gemini. +/// +/// https://ai.google.dev/gemini-api/docs/available-regions +static SUPPORTED_COUNTRIES: LazyLock> = LazyLock::new(|| { + vec![ + "DZ", // Algeria + "AS", // American Samoa + "AO", // Angola + "AI", // Anguilla + "AQ", // Antarctica + "AG", // Antigua and Barbuda + "AR", // Argentina + "AM", // Armenia + "AW", // Aruba + "AU", // Australia + "AT", // Austria + "AZ", // Azerbaijan + "BS", // The Bahamas + "BH", // Bahrain + "BD", // Bangladesh + "BB", // Barbados + "BE", // Belgium + "BZ", // Belize + "BJ", // Benin + "BM", // Bermuda + "BT", // Bhutan + "BO", // Bolivia + "BW", // Botswana + "BR", // Brazil + "IO", // British Indian Ocean Territory + "VG", // British Virgin Islands + "BN", // Brunei + "BG", // Bulgaria + "BF", // Burkina Faso + "BI", // Burundi + "CV", // Cabo Verde + "KH", // Cambodia + "CM", // Cameroon + "CA", // Canada + "BQ", // Caribbean Netherlands + "KY", // Cayman Islands + "CF", // Central African Republic + "TD", // Chad + "CL", // Chile + "CX", // Christmas Island + "CC", // Cocos (Keeling) Islands + "CO", // Colombia + "KM", // Comoros + "CK", // Cook Islands + "CI", // Côte d'Ivoire + "CR", // Costa Rica + "HR", // Croatia + "CW", // Curaçao + "CZ", // Czech Republic + "CD", // Democratic Republic of the Congo + "DK", // Denmark + "DJ", // Djibouti + "DM", // Dominica + "DO", // Dominican Republic + "EC", // Ecuador + "EG", // Egypt + "SV", // El Salvador + "GQ", // Equatorial Guinea + "ER", // Eritrea + "EE", // Estonia + "SZ", // Eswatini + "ET", // Ethiopia + "FK", // Falkland Islands (Islas Malvinas) + "FJ", // Fiji + "FI", // Finland + "FR", // France + "GA", // Gabon + "GM", // The Gambia + "GE", // Georgia + "DE", // Germany + "GH", // Ghana + "GI", // Gibraltar + "GR", // Greece + "GD", // Grenada + "GU", // Guam + "GT", // Guatemala + "GG", // Guernsey + "GN", // Guinea + "GW", // Guinea-Bissau + "GY", // Guyana + "HT", // Haiti + "HM", // Heard Island and McDonald Islands + "HN", // Honduras + "HU", // Hungary + "IS", // Iceland + "IN", // India + "ID", // Indonesia + "IQ", // Iraq + "IE", // Ireland + "IM", // Isle of Man + "IL", // Israel + "IT", // Italy + "JM", // Jamaica + "JP", // Japan + "JE", // Jersey + "JO", // Jordan + "KZ", // Kazakhstan + "KE", // Kenya + "KI", // Kiribati + "KG", // Kyrgyzstan + "KW", // Kuwait + "LA", // Laos + "LV", // Latvia + "LB", // Lebanon + "LS", // Lesotho + "LR", // Liberia + "LY", // Libya + "LI", // Liechtenstein + "LT", // Lithuania + "LU", // Luxembourg + "MG", // Madagascar + "MW", // Malawi + "MY", // Malaysia + "MV", // Maldives + "ML", // Mali + "MT", // Malta + "MH", // Marshall Islands + "MR", // Mauritania + "MU", // Mauritius + "MX", // Mexico + "FM", // Micronesia + "MN", // Mongolia + "MS", // Montserrat + "MA", // Morocco + "MZ", // Mozambique + "NA", // Namibia + "NR", // Nauru + "NP", // Nepal + "NL", // Netherlands + "NC", // New Caledonia + "NZ", // New Zealand + "NI", // Nicaragua + "NE", // Niger + "NG", // Nigeria + "NU", // Niue + "NF", // Norfolk Island + "MP", // Northern Mariana Islands + "NO", // Norway + "OM", // Oman + "PK", // Pakistan + "PW", // Palau + "PS", // Palestine + "PA", // Panama + "PG", // Papua New Guinea + "PY", // Paraguay + "PE", // Peru + "PH", // Philippines + "PN", // Pitcairn Islands + "PL", // Poland + "PT", // Portugal + "PR", // Puerto Rico + "QA", // Qatar + "CY", // Republic of Cyprus + "CG", // Republic of the Congo + "RO", // Romania + "RW", // Rwanda + "BL", // Saint Barthélemy + "KN", // Saint Kitts and Nevis + "LC", // Saint Lucia + "PM", // Saint Pierre and Miquelon + "VC", // Saint Vincent and the Grenadines + "SH", // Saint Helena, Ascension and Tristan da Cunha + "WS", // Samoa + "ST", // São Tomé and Príncipe + "SA", // Saudi Arabia + "SN", // Senegal + "SC", // Seychelles + "SL", // Sierra Leone + "SG", // Singapore + "SK", // Slovakia + "SI", // Slovenia + "SB", // Solomon Islands + "SO", // Somalia + "ZA", // South Africa + "GS", // South Georgia and the South Sandwich Islands + "KR", // South Korea + "SS", // South Sudan + "ES", // Spain + "LK", // Sri Lanka + "SD", // Sudan + "SR", // Suriname + "SE", // Sweden + "CH", // Switzerland + "TW", // Taiwan + "TJ", // Tajikistan + "TZ", // Tanzania + "TH", // Thailand + "TL", // Timor-Leste + "TG", // Togo + "TK", // Tokelau + "TO", // Tonga + "TT", // Trinidad and Tobago + "TN", // Tunisia + "TR", // Türkiye + "TM", // Turkmenistan + "TC", // Turks and Caicos Islands + "TV", // Tuvalu + "UG", // Uganda + "GB", // United Kingdom + "AE", // United Arab Emirates + "US", // United States + "UM", // United States Minor Outlying Islands + "VI", // U.S. Virgin Islands + "UY", // Uruguay + "UZ", // Uzbekistan + "VU", // Vanuatu + "VE", // Venezuela + "VN", // Vietnam + "WF", // Wallis and Futuna + "EH", // Western Sahara + "YE", // Yemen + "ZM", // Zambia + "ZW", // Zimbabwe + ] + .into_iter() + .collect() +}); diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 4fb62831b6..7ef6d1413a 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,3 +1,5 @@ +mod supported_countries; + use anyhow::{anyhow, Context, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; @@ -7,6 +9,8 @@ use serde_json::Value; use std::{convert::TryFrom, future::Future, time::Duration}; use strum::EnumIter; +pub use supported_countries::*; + pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; fn is_none_or_empty, U>(opt: &Option) -> bool { diff --git a/crates/open_ai/src/supported_countries.rs b/crates/open_ai/src/supported_countries.rs new file mode 100644 index 0000000000..4b28694023 --- /dev/null +++ b/crates/open_ai/src/supported_countries.rs @@ -0,0 +1,207 @@ +use std::collections::HashSet; +use std::sync::LazyLock; + +/// Returns whether the given country code is supported by OpenAI. +/// +/// https://platform.openai.com/docs/supported-countries +pub fn is_supported_country(country_code: &str) -> bool { + SUPPORTED_COUNTRIES.contains(&country_code) +} + +/// The list of country codes supported by OpenAI. +/// +/// https://platform.openai.com/docs/supported-countries +static SUPPORTED_COUNTRIES: LazyLock> = LazyLock::new(|| { + vec![ + "AL", // Albania + "DZ", // Algeria + "AF", // Afghanistan + "AD", // Andorra + "AO", // Angola + "AG", // Antigua and Barbuda + "AR", // Argentina + "AM", // Armenia + "AU", // Australia + "AT", // Austria + "AZ", // Azerbaijan + "BS", // Bahamas + "BH", // Bahrain + "BD", // Bangladesh + "BB", // Barbados + "BE", // Belgium + "BZ", // Belize + "BJ", // Benin + "BT", // Bhutan + "BO", // Bolivia + "BA", // Bosnia and Herzegovina + "BW", // Botswana + "BR", // Brazil + "BN", // Brunei + "BG", // Bulgaria + "BF", // Burkina Faso + "BI", // Burundi + "CV", // Cabo Verde + "KH", // Cambodia + "CM", // Cameroon + "CA", // Canada + "CF", // Central African Republic + "TD", // Chad + "CL", // Chile + "CO", // Colombia + "KM", // Comoros + "CG", // Congo (Brazzaville) + "CD", // Congo (DRC) + "CR", // Costa Rica + "CI", // Côte d'Ivoire + "HR", // Croatia + "CY", // Cyprus + "CZ", // Czechia (Czech Republic) + "DK", // Denmark + "DJ", // Djibouti + "DM", // Dominica + "DO", // Dominican Republic + "EC", // Ecuador + "EG", // Egypt + "SV", // El Salvador + "GQ", // Equatorial Guinea + "ER", // Eritrea + "EE", // Estonia + "SZ", // Eswatini (Swaziland) + "ET", // Ethiopia + "FJ", // Fiji + "FI", // Finland + "FR", // France + "GA", // Gabon + "GM", // Gambia + "GE", // Georgia + "DE", // Germany + "GH", // Ghana + "GR", // Greece + "GD", // Grenada + "GT", // Guatemala + "GN", // Guinea + "GW", // Guinea-Bissau + "GY", // Guyana + "HT", // Haiti + "VA", // Holy See (Vatican City) + "HN", // Honduras + "HU", // Hungary + "IS", // Iceland + "IN", // India + "ID", // Indonesia + "IQ", // Iraq + "IE", // Ireland + "IL", // Israel + "IT", // Italy + "JM", // Jamaica + "JP", // Japan + "JO", // Jordan + "KZ", // Kazakhstan + "KE", // Kenya + "KI", // Kiribati + "KW", // Kuwait + "KG", // Kyrgyzstan + "LA", // Laos + "LV", // Latvia + "LB", // Lebanon + "LS", // Lesotho + "LR", // Liberia + "LY", // Libya + "LI", // Liechtenstein + "LT", // Lithuania + "LU", // Luxembourg + "MG", // Madagascar + "MW", // Malawi + "MY", // Malaysia + "MV", // Maldives + "ML", // Mali + "MT", // Malta + "MH", // Marshall Islands + "MR", // Mauritania + "MU", // Mauritius + "MX", // Mexico + "FM", // Micronesia + "MD", // Moldova + "MC", // Monaco + "MN", // Mongolia + "ME", // Montenegro + "MA", // Morocco + "MZ", // Mozambique + "MM", // Myanmar + "NA", // Namibia + "NR", // Nauru + "NP", // Nepal + "NL", // Netherlands + "NZ", // New Zealand + "NI", // Nicaragua + "NE", // Niger + "NG", // Nigeria + "MK", // North Macedonia + "NO", // Norway + "OM", // Oman + "PK", // Pakistan + "PW", // Palau + "PS", // Palestine + "PA", // Panama + "PG", // Papua New Guinea + "PY", // Paraguay + "PE", // Peru + "PH", // Philippines + "PL", // Poland + "PT", // Portugal + "QA", // Qatar + "RO", // Romania + "RW", // Rwanda + "KN", // Saint Kitts and Nevis + "LC", // Saint Lucia + "VC", // Saint Vincent and the Grenadines + "WS", // Samoa + "SM", // San Marino + "ST", // Sao Tome and Principe + "SA", // Saudi Arabia + "SN", // Senegal + "RS", // Serbia + "SC", // Seychelles + "SL", // Sierra Leone + "SG", // Singapore + "SK", // Slovakia + "SI", // Slovenia + "SB", // Solomon Islands + "SO", // Somalia + "ZA", // South Africa + "KR", // South Korea + "SS", // South Sudan + "ES", // Spain + "LK", // Sri Lanka + "SR", // Suriname + "SE", // Sweden + "CH", // Switzerland + "SD", // Sudan + "TW", // Taiwan + "TJ", // Tajikistan + "TZ", // Tanzania + "TH", // Thailand + "TL", // Timor-Leste (East Timor) + "TG", // Togo + "TO", // Tonga + "TT", // Trinidad and Tobago + "TN", // Tunisia + "TR", // Turkey + "TM", // Turkmenistan + "TV", // Tuvalu + "UG", // Uganda + "UA", // Ukraine (with certain exceptions) + "AE", // United Arab Emirates + "GB", // United Kingdom + "US", // United States of America + "UY", // Uruguay + "UZ", // Uzbekistan + "VU", // Vanuatu + "VN", // Vietnam + "YE", // Yemen + "ZM", // Zambia + "ZW", // Zimbabwe + ] + .into_iter() + .collect() +}); diff --git a/typos.toml b/typos.toml index 8927f9eea9..732a9c79dd 100644 --- a/typos.toml +++ b/typos.toml @@ -6,6 +6,12 @@ extend-exclude = [ # File suffixes aren't typos "assets/icons/file_icons/file_types.json", "crates/extensions_ui/src/extension_suggest.rs", + + # Some countries codes are flagged as typos. + "crates/anthropic/src/supported_countries.rs", + "crates/google_ai/src/supported_countries.rs", + "crates/open_ai/src/supported_countries.rs", + # Stripe IDs are flagged as typos. "crates/collab/src/db/tests/processed_stripe_event_tests.rs", # Not our typos