Authorize access to language model providers based on country (#15859)

This PR updates the LLM service to authorize access to language model
providers based on the requester's country.

We detect the country using Cloudflare's
[`CF-IPCountry`](https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry)
header.

The country code is then checked against the list of supported countries
for the given LLM provider. Countries that are not supported will
receive an `HTTP 451: Unavailable For Legal Reasons` response.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-08-06 11:49:04 -04:00 committed by GitHub
parent 9c6ccaffe3
commit cf5f4dddf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 921 additions and 1 deletions

1
Cargo.lock generated
View File

@ -2464,6 +2464,7 @@ dependencies = [
"headless", "headless",
"hex", "hex",
"http_client", "http_client",
"hyper",
"indoc", "indoc",
"jsonwebtoken", "jsonwebtoken",
"language", "language",

View File

@ -340,6 +340,7 @@ git2 = { version = "0.19", default-features = false }
globset = "0.4" globset = "0.4"
heed = { version = "0.20.1", features = ["read-txn-no-tls"] } heed = { version = "0.20.1", features = ["read-txn-no-tls"] }
hex = "0.4.3" hex = "0.4.3"
hyper = "0.14"
html5ever = "0.27.0" html5ever = "0.27.0"
ignore = "0.4.22" ignore = "0.4.22"
image = "0.25.1" image = "0.25.1"

View File

@ -1,3 +1,5 @@
mod supported_countries;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@ -6,6 +8,8 @@ use serde::{Deserialize, Serialize};
use std::time::Duration; use std::time::Duration;
use strum::EnumIter; use strum::EnumIter;
pub use supported_countries::*;
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com"; pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]

View File

@ -0,0 +1,194 @@
use std::collections::HashSet;
use std::sync::LazyLock;
/// Returns whether the given country code is supported by Anthropic.
///
/// https://www.anthropic.com/supported-countries
pub fn is_supported_country(country_code: &str) -> bool {
SUPPORTED_COUNTRIES.contains(&country_code)
}
/// The list of country codes supported by Anthropic.
///
/// https://www.anthropic.com/supported-countries
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
vec![
"AL", // Albania
"DZ", // Algeria
"AD", // Andorra
"AO", // Angola
"AG", // Antigua and Barbuda
"AR", // Argentina
"AM", // Armenia
"AU", // Australia
"AT", // Austria
"AZ", // Azerbaijan
"BS", // Bahamas
"BH", // Bahrain
"BD", // Bangladesh
"BB", // Barbados
"BE", // Belgium
"BZ", // Belize
"BJ", // Benin
"BT", // Bhutan
"BO", // Bolivia
"BA", // Bosnia and Herzegovina
"BW", // Botswana
"BR", // Brazil
"BN", // Brunei
"BG", // Bulgaria
"BF", // Burkina Faso
"BI", // Burundi
"CV", // Cabo Verde
"KH", // Cambodia
"CM", // Cameroon
"CA", // Canada
"TD", // Chad
"CL", // Chile
"CO", // Colombia
"KM", // Comoros
"CG", // Congo (Brazzaville)
"CR", // Costa Rica
"CI", // Côte d'Ivoire
"HR", // Croatia
"CY", // Cyprus
"CZ", // Czechia (Czech Republic)
"DK", // Denmark
"DJ", // Djibouti
"DM", // Dominica
"DO", // Dominican Republic
"EC", // Ecuador
"EG", // Egypt
"SV", // El Salvador
"GQ", // Equatorial Guinea
"EE", // Estonia
"SZ", // Eswatini
"FJ", // Fiji
"FI", // Finland
"FR", // France
"GA", // Gabon
"GM", // Gambia
"GE", // Georgia
"DE", // Germany
"GH", // Ghana
"GR", // Greece
"GD", // Grenada
"GT", // Guatemala
"GN", // Guinea
"GW", // Guinea-Bissau
"GY", // Guyana
"HT", // Haiti
"HN", // Honduras
"HU", // Hungary
"IS", // Iceland
"IN", // India
"ID", // Indonesia
"IQ", // Iraq
"IE", // Ireland
"IL", // Israel
"IT", // Italy
"JM", // Jamaica
"JP", // Japan
"JO", // Jordan
"KZ", // Kazakhstan
"KE", // Kenya
"KI", // Kiribati
"KW", // Kuwait
"KG", // Kyrgyzstan
"LA", // Laos
"LV", // Latvia
"LB", // Lebanon
"LS", // Lesotho
"LR", // Liberia
"LI", // Liechtenstein
"LT", // Lithuania
"LU", // Luxembourg
"MG", // Madagascar
"MW", // Malawi
"MY", // Malaysia
"MV", // Maldives
"MT", // Malta
"MH", // Marshall Islands
"MR", // Mauritania
"MU", // Mauritius
"MX", // Mexico
"FM", // Micronesia
"MD", // Moldova
"MC", // Monaco
"MN", // Mongolia
"ME", // Montenegro
"MA", // Morocco
"MZ", // Mozambique
"NA", // Namibia
"NR", // Nauru
"NP", // Nepal
"NL", // Netherlands
"NZ", // New Zealand
"NE", // Niger
"NG", // Nigeria
"MK", // North Macedonia
"NO", // Norway
"OM", // Oman
"PK", // Pakistan
"PW", // Palau
"PS", // Palestine
"PA", // Panama
"PG", // Papua New Guinea
"PY", // Paraguay
"PE", // Peru
"PH", // Philippines
"PL", // Poland
"PT", // Portugal
"QA", // Qatar
"RO", // Romania
"RW", // Rwanda
"KN", // Saint Kitts and Nevis
"LC", // Saint Lucia
"VC", // Saint Vincent and the Grenadines
"WS", // Samoa
"SM", // San Marino
"ST", // São Tomé and Príncipe
"SA", // Saudi Arabia
"SN", // Senegal
"RS", // Serbia
"SC", // Seychelles
"SL", // Sierra Leone
"SG", // Singapore
"SK", // Slovakia
"SI", // Slovenia
"SB", // Solomon Islands
"ZA", // South Africa
"KR", // South Korea
"ES", // Spain
"LK", // Sri Lanka
"SR", // Suriname
"SE", // Sweden
"CH", // Switzerland
"TW", // Taiwan
"TJ", // Tajikistan
"TZ", // Tanzania
"TH", // Thailand
"TL", // Timor-Leste
"TG", // Togo
"TO", // Tonga
"TT", // Trinidad and Tobago
"TN", // Tunisia
"TR", // Türkiye (Turkey)
"TM", // Turkmenistan
"TV", // Tuvalu
"UG", // Uganda
"UA", // Ukraine (except Crimea, Donetsk, and Luhansk regions)
"AE", // United Arab Emirates
"GB", // United Kingdom
"US", // United States of America
"UY", // Uruguay
"UZ", // Uzbekistan
"VU", // Vanuatu
"VA", // Vatican City
"VN", // Vietnam
"ZM", // Zambia
"ZW", // Zimbabwe
]
.into_iter()
.collect()
});

View File

@ -90,6 +90,7 @@ fs = { workspace = true, features = ["test-support"] }
git = { workspace = true, features = ["test-support"] } git = { workspace = true, features = ["test-support"] }
git_hosting_providers.workspace = true git_hosting_providers.workspace = true
gpui = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] }
hyper.workspace = true
indoc.workspace = true indoc.workspace = true
language = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] } language_model = { workspace = true, features = ["test-support"] }

View File

@ -185,6 +185,46 @@ impl Config {
_ => "https://zed.dev", _ => "https://zed.dev",
} }
} }
#[cfg(test)]
pub fn test() -> Self {
Self {
http_port: 0,
database_url: "".into(),
database_max_connections: 0,
api_token: "".into(),
invite_link_prefix: "".into(),
live_kit_server: None,
live_kit_key: None,
live_kit_secret: None,
llm_api_secret: None,
rust_log: None,
log_json: None,
zed_environment: "test".into(),
blob_store_url: None,
blob_store_region: None,
blob_store_access_key: None,
blob_store_secret_key: None,
blob_store_bucket: None,
openai_api_key: None,
google_ai_api_key: None,
anthropic_api_key: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,
clickhouse_database: None,
zed_client_checksum_seed: None,
slack_panics_webhook: None,
auto_join_channel_id: None,
migrations_path: None,
seed_path: None,
stripe_api_key: None,
stripe_price_id: None,
supermaven_admin_api_key: None,
qwen2_7b_api_key: None,
qwen2_7b_api_url: None,
}
}
} }
/// The service mode that collab should run in. /// The service mode that collab should run in.

View File

@ -1,7 +1,11 @@
mod authorization;
mod token; mod token;
use crate::api::CloudflareIpCountryHeader;
use crate::llm::authorization::authorize_access_to_language_model;
use crate::{executor::Executor, Config, Error, Result}; use crate::{executor::Executor, Config, Error, Result};
use anyhow::Context as _; use anyhow::Context as _;
use axum::TypedHeader;
use axum::{ use axum::{
body::Body, body::Body,
http::{self, HeaderName, HeaderValue, Request, StatusCode}, http::{self, HeaderName, HeaderValue, Request, StatusCode},
@ -91,9 +95,18 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
async fn perform_completion( async fn perform_completion(
Extension(state): Extension<Arc<LlmState>>, Extension(state): Extension<Arc<LlmState>>,
Extension(_claims): Extension<LlmTokenClaims>, Extension(claims): Extension<LlmTokenClaims>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PerformCompletionParams>, Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> { ) -> Result<impl IntoResponse> {
authorize_access_to_language_model(
&state.config,
&claims,
country_code_header.map(|header| header.to_string()),
params.provider,
&params.model,
)?;
match params.provider { match params.provider {
LanguageModelProvider::Anthropic => { LanguageModelProvider::Anthropic => {
let api_key = state let api_key = state

View File

@ -0,0 +1,213 @@
use reqwest::StatusCode;
use rpc::LanguageModelProvider;
use crate::llm::LlmTokenClaims;
use crate::{Config, Error, Result};
pub fn authorize_access_to_language_model(
config: &Config,
_claims: &LlmTokenClaims,
country_code: Option<String>,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
authorize_access_for_country(config, country_code, provider, model)?;
Ok(())
}
fn authorize_access_for_country(
config: &Config,
country_code: Option<String>,
provider: LanguageModelProvider,
_model: &str,
) -> Result<()> {
// In development we won't have the `CF-IPCountry` header, so we can't check
// the country code.
//
// This shouldn't be necessary, as anyone running in development will need to provide
// their own API credentials in order to use an LLM provider.
if config.is_development() {
return Ok(());
}
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
let country_code = match country_code.as_deref() {
// `XX` - Used for clients without country code data.
None | Some("XX") => Err(Error::http(
StatusCode::BAD_REQUEST,
"no country code".to_string(),
))?,
// `T1` - Used for clients using the Tor network.
Some("T1") => Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to {provider:?} models is not available over Tor"),
))?,
Some(country_code) => country_code,
};
let is_country_supported_by_provider = match provider {
LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
LanguageModelProvider::Zed => true,
};
if !is_country_supported_by_provider {
Err(Error::http(
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
format!("access to {provider:?} models is not available in your region"),
))?
}
Ok(())
}
#[cfg(test)]
mod tests {
use axum::response::IntoResponse;
use pretty_assertions::assert_eq;
use rpc::proto::Plan;
use super::*;
#[gpui::test]
async fn test_authorize_access_to_language_model_with_supported_country(
_cx: &mut gpui::TestAppContext,
) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "US"), // United States
(LanguageModelProvider::Anthropic, "GB"), // United Kingdom
(LanguageModelProvider::OpenAi, "US"), // United States
(LanguageModelProvider::OpenAi, "GB"), // United Kingdom
(LanguageModelProvider::Google, "US"), // United States
(LanguageModelProvider::Google, "GB"), // United Kingdom
];
for (provider, country_code) in cases {
authorize_access_to_language_model(
&config,
&claims,
Some(country_code.into()),
provider,
"the-model",
)
.unwrap_or_else(|_| {
panic!("expected authorization to return Ok for {provider:?}: {country_code}")
})
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_with_unsupported_country(
_cx: &mut gpui::TestAppContext,
) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "AF"), // Afghanistan
(LanguageModelProvider::Anthropic, "BY"), // Belarus
(LanguageModelProvider::Anthropic, "CF"), // Central African Republic
(LanguageModelProvider::Anthropic, "CN"), // China
(LanguageModelProvider::Anthropic, "CU"), // Cuba
(LanguageModelProvider::Anthropic, "ER"), // Eritrea
(LanguageModelProvider::Anthropic, "ET"), // Ethiopia
(LanguageModelProvider::Anthropic, "IR"), // Iran
(LanguageModelProvider::Anthropic, "KP"), // North Korea
(LanguageModelProvider::Anthropic, "XK"), // Kosovo
(LanguageModelProvider::Anthropic, "LY"), // Libya
(LanguageModelProvider::Anthropic, "MM"), // Myanmar
(LanguageModelProvider::Anthropic, "RU"), // Russia
(LanguageModelProvider::Anthropic, "SO"), // Somalia
(LanguageModelProvider::Anthropic, "SS"), // South Sudan
(LanguageModelProvider::Anthropic, "SD"), // Sudan
(LanguageModelProvider::Anthropic, "SY"), // Syria
(LanguageModelProvider::Anthropic, "VE"), // Venezuela
(LanguageModelProvider::Anthropic, "YE"), // Yemen
(LanguageModelProvider::OpenAi, "KP"), // North Korea
(LanguageModelProvider::Google, "KP"), // North Korea
];
for (provider, country_code) in cases {
let error_response = authorize_access_to_language_model(
&config,
&claims,
Some(country_code.into()),
provider,
"the-model",
)
.expect_err(&format!(
"expected authorization to return an error for {provider:?}: {country_code}"
))
.into_response();
assert_eq!(
error_response.status(),
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
);
let response_body = hyper::body::to_bytes(error_response.into_body())
.await
.unwrap()
.to_vec();
assert_eq!(
String::from_utf8(response_body).unwrap(),
format!("access to {provider:?} models is not available in your region")
);
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "T1"), // Tor
(LanguageModelProvider::OpenAi, "T1"), // Tor
(LanguageModelProvider::Google, "T1"), // Tor
(LanguageModelProvider::Zed, "T1"), // Tor
];
for (provider, country_code) in cases {
let error_response = authorize_access_to_language_model(
&config,
&claims,
Some(country_code.into()),
provider,
"the-model",
)
.expect_err(&format!(
"expected authorization to return an error for {provider:?}: {country_code}"
))
.into_response();
assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
let response_body = hyper::body::to_bytes(error_response.into_body())
.await
.unwrap()
.to_vec();
assert_eq!(
String::from_utf8(response_body).unwrap(),
format!("access to {provider:?} models is not available over Tor")
);
}
}
}

View File

@ -1,8 +1,12 @@
mod supported_countries;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::HttpClient; use http_client::HttpClient;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub use supported_countries::*;
pub const API_URL: &str = "https://generativelanguage.googleapis.com"; pub const API_URL: &str = "https://generativelanguage.googleapis.com";
pub async fn stream_generate_content( pub async fn stream_generate_content(

View File

@ -0,0 +1,232 @@
use std::collections::HashSet;
use std::sync::LazyLock;
/// Returns whether the given country code is supported by Google Gemini.
///
/// https://ai.google.dev/gemini-api/docs/available-regions
pub fn is_supported_country(country_code: &str) -> bool {
SUPPORTED_COUNTRIES.contains(&country_code)
}
/// The list of country codes supported by Google Gemini.
///
/// https://ai.google.dev/gemini-api/docs/available-regions
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
vec![
"DZ", // Algeria
"AS", // American Samoa
"AO", // Angola
"AI", // Anguilla
"AQ", // Antarctica
"AG", // Antigua and Barbuda
"AR", // Argentina
"AM", // Armenia
"AW", // Aruba
"AU", // Australia
"AT", // Austria
"AZ", // Azerbaijan
"BS", // The Bahamas
"BH", // Bahrain
"BD", // Bangladesh
"BB", // Barbados
"BE", // Belgium
"BZ", // Belize
"BJ", // Benin
"BM", // Bermuda
"BT", // Bhutan
"BO", // Bolivia
"BW", // Botswana
"BR", // Brazil
"IO", // British Indian Ocean Territory
"VG", // British Virgin Islands
"BN", // Brunei
"BG", // Bulgaria
"BF", // Burkina Faso
"BI", // Burundi
"CV", // Cabo Verde
"KH", // Cambodia
"CM", // Cameroon
"CA", // Canada
"BQ", // Caribbean Netherlands
"KY", // Cayman Islands
"CF", // Central African Republic
"TD", // Chad
"CL", // Chile
"CX", // Christmas Island
"CC", // Cocos (Keeling) Islands
"CO", // Colombia
"KM", // Comoros
"CK", // Cook Islands
"CI", // Côte d'Ivoire
"CR", // Costa Rica
"HR", // Croatia
"CW", // Curaçao
"CZ", // Czech Republic
"CD", // Democratic Republic of the Congo
"DK", // Denmark
"DJ", // Djibouti
"DM", // Dominica
"DO", // Dominican Republic
"EC", // Ecuador
"EG", // Egypt
"SV", // El Salvador
"GQ", // Equatorial Guinea
"ER", // Eritrea
"EE", // Estonia
"SZ", // Eswatini
"ET", // Ethiopia
"FK", // Falkland Islands (Islas Malvinas)
"FJ", // Fiji
"FI", // Finland
"FR", // France
"GA", // Gabon
"GM", // The Gambia
"GE", // Georgia
"DE", // Germany
"GH", // Ghana
"GI", // Gibraltar
"GR", // Greece
"GD", // Grenada
"GU", // Guam
"GT", // Guatemala
"GG", // Guernsey
"GN", // Guinea
"GW", // Guinea-Bissau
"GY", // Guyana
"HT", // Haiti
"HM", // Heard Island and McDonald Islands
"HN", // Honduras
"HU", // Hungary
"IS", // Iceland
"IN", // India
"ID", // Indonesia
"IQ", // Iraq
"IE", // Ireland
"IM", // Isle of Man
"IL", // Israel
"IT", // Italy
"JM", // Jamaica
"JP", // Japan
"JE", // Jersey
"JO", // Jordan
"KZ", // Kazakhstan
"KE", // Kenya
"KI", // Kiribati
"KG", // Kyrgyzstan
"KW", // Kuwait
"LA", // Laos
"LV", // Latvia
"LB", // Lebanon
"LS", // Lesotho
"LR", // Liberia
"LY", // Libya
"LI", // Liechtenstein
"LT", // Lithuania
"LU", // Luxembourg
"MG", // Madagascar
"MW", // Malawi
"MY", // Malaysia
"MV", // Maldives
"ML", // Mali
"MT", // Malta
"MH", // Marshall Islands
"MR", // Mauritania
"MU", // Mauritius
"MX", // Mexico
"FM", // Micronesia
"MN", // Mongolia
"MS", // Montserrat
"MA", // Morocco
"MZ", // Mozambique
"NA", // Namibia
"NR", // Nauru
"NP", // Nepal
"NL", // Netherlands
"NC", // New Caledonia
"NZ", // New Zealand
"NI", // Nicaragua
"NE", // Niger
"NG", // Nigeria
"NU", // Niue
"NF", // Norfolk Island
"MP", // Northern Mariana Islands
"NO", // Norway
"OM", // Oman
"PK", // Pakistan
"PW", // Palau
"PS", // Palestine
"PA", // Panama
"PG", // Papua New Guinea
"PY", // Paraguay
"PE", // Peru
"PH", // Philippines
"PN", // Pitcairn Islands
"PL", // Poland
"PT", // Portugal
"PR", // Puerto Rico
"QA", // Qatar
"CY", // Republic of Cyprus
"CG", // Republic of the Congo
"RO", // Romania
"RW", // Rwanda
"BL", // Saint Barthélemy
"KN", // Saint Kitts and Nevis
"LC", // Saint Lucia
"PM", // Saint Pierre and Miquelon
"VC", // Saint Vincent and the Grenadines
"SH", // Saint Helena, Ascension and Tristan da Cunha
"WS", // Samoa
"ST", // São Tomé and Príncipe
"SA", // Saudi Arabia
"SN", // Senegal
"SC", // Seychelles
"SL", // Sierra Leone
"SG", // Singapore
"SK", // Slovakia
"SI", // Slovenia
"SB", // Solomon Islands
"SO", // Somalia
"ZA", // South Africa
"GS", // South Georgia and the South Sandwich Islands
"KR", // South Korea
"SS", // South Sudan
"ES", // Spain
"LK", // Sri Lanka
"SD", // Sudan
"SR", // Suriname
"SE", // Sweden
"CH", // Switzerland
"TW", // Taiwan
"TJ", // Tajikistan
"TZ", // Tanzania
"TH", // Thailand
"TL", // Timor-Leste
"TG", // Togo
"TK", // Tokelau
"TO", // Tonga
"TT", // Trinidad and Tobago
"TN", // Tunisia
"TR", // Türkiye
"TM", // Turkmenistan
"TC", // Turks and Caicos Islands
"TV", // Tuvalu
"UG", // Uganda
"GB", // United Kingdom
"AE", // United Arab Emirates
"US", // United States
"UM", // United States Minor Outlying Islands
"VI", // U.S. Virgin Islands
"UY", // Uruguay
"UZ", // Uzbekistan
"VU", // Vanuatu
"VE", // Venezuela
"VN", // Vietnam
"WF", // Wallis and Futuna
"EH", // Western Sahara
"YE", // Yemen
"ZM", // Zambia
"ZW", // Zimbabwe
]
.into_iter()
.collect()
});

View File

@ -1,3 +1,5 @@
mod supported_countries;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@ -7,6 +9,8 @@ use serde_json::Value;
use std::{convert::TryFrom, future::Future, time::Duration}; use std::{convert::TryFrom, future::Future, time::Duration};
use strum::EnumIter; use strum::EnumIter;
pub use supported_countries::*;
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool { fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {

View File

@ -0,0 +1,207 @@
use std::collections::HashSet;
use std::sync::LazyLock;
/// Returns whether the given country code is supported by OpenAI.
///
/// https://platform.openai.com/docs/supported-countries
pub fn is_supported_country(country_code: &str) -> bool {
SUPPORTED_COUNTRIES.contains(&country_code)
}
/// The list of country codes supported by OpenAI.
///
/// https://platform.openai.com/docs/supported-countries
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
vec![
"AL", // Albania
"DZ", // Algeria
"AF", // Afghanistan
"AD", // Andorra
"AO", // Angola
"AG", // Antigua and Barbuda
"AR", // Argentina
"AM", // Armenia
"AU", // Australia
"AT", // Austria
"AZ", // Azerbaijan
"BS", // Bahamas
"BH", // Bahrain
"BD", // Bangladesh
"BB", // Barbados
"BE", // Belgium
"BZ", // Belize
"BJ", // Benin
"BT", // Bhutan
"BO", // Bolivia
"BA", // Bosnia and Herzegovina
"BW", // Botswana
"BR", // Brazil
"BN", // Brunei
"BG", // Bulgaria
"BF", // Burkina Faso
"BI", // Burundi
"CV", // Cabo Verde
"KH", // Cambodia
"CM", // Cameroon
"CA", // Canada
"CF", // Central African Republic
"TD", // Chad
"CL", // Chile
"CO", // Colombia
"KM", // Comoros
"CG", // Congo (Brazzaville)
"CD", // Congo (DRC)
"CR", // Costa Rica
"CI", // Côte d'Ivoire
"HR", // Croatia
"CY", // Cyprus
"CZ", // Czechia (Czech Republic)
"DK", // Denmark
"DJ", // Djibouti
"DM", // Dominica
"DO", // Dominican Republic
"EC", // Ecuador
"EG", // Egypt
"SV", // El Salvador
"GQ", // Equatorial Guinea
"ER", // Eritrea
"EE", // Estonia
"SZ", // Eswatini (Swaziland)
"ET", // Ethiopia
"FJ", // Fiji
"FI", // Finland
"FR", // France
"GA", // Gabon
"GM", // Gambia
"GE", // Georgia
"DE", // Germany
"GH", // Ghana
"GR", // Greece
"GD", // Grenada
"GT", // Guatemala
"GN", // Guinea
"GW", // Guinea-Bissau
"GY", // Guyana
"HT", // Haiti
"VA", // Holy See (Vatican City)
"HN", // Honduras
"HU", // Hungary
"IS", // Iceland
"IN", // India
"ID", // Indonesia
"IQ", // Iraq
"IE", // Ireland
"IL", // Israel
"IT", // Italy
"JM", // Jamaica
"JP", // Japan
"JO", // Jordan
"KZ", // Kazakhstan
"KE", // Kenya
"KI", // Kiribati
"KW", // Kuwait
"KG", // Kyrgyzstan
"LA", // Laos
"LV", // Latvia
"LB", // Lebanon
"LS", // Lesotho
"LR", // Liberia
"LY", // Libya
"LI", // Liechtenstein
"LT", // Lithuania
"LU", // Luxembourg
"MG", // Madagascar
"MW", // Malawi
"MY", // Malaysia
"MV", // Maldives
"ML", // Mali
"MT", // Malta
"MH", // Marshall Islands
"MR", // Mauritania
"MU", // Mauritius
"MX", // Mexico
"FM", // Micronesia
"MD", // Moldova
"MC", // Monaco
"MN", // Mongolia
"ME", // Montenegro
"MA", // Morocco
"MZ", // Mozambique
"MM", // Myanmar
"NA", // Namibia
"NR", // Nauru
"NP", // Nepal
"NL", // Netherlands
"NZ", // New Zealand
"NI", // Nicaragua
"NE", // Niger
"NG", // Nigeria
"MK", // North Macedonia
"NO", // Norway
"OM", // Oman
"PK", // Pakistan
"PW", // Palau
"PS", // Palestine
"PA", // Panama
"PG", // Papua New Guinea
"PY", // Paraguay
"PE", // Peru
"PH", // Philippines
"PL", // Poland
"PT", // Portugal
"QA", // Qatar
"RO", // Romania
"RW", // Rwanda
"KN", // Saint Kitts and Nevis
"LC", // Saint Lucia
"VC", // Saint Vincent and the Grenadines
"WS", // Samoa
"SM", // San Marino
"ST", // Sao Tome and Principe
"SA", // Saudi Arabia
"SN", // Senegal
"RS", // Serbia
"SC", // Seychelles
"SL", // Sierra Leone
"SG", // Singapore
"SK", // Slovakia
"SI", // Slovenia
"SB", // Solomon Islands
"SO", // Somalia
"ZA", // South Africa
"KR", // South Korea
"SS", // South Sudan
"ES", // Spain
"LK", // Sri Lanka
"SR", // Suriname
"SE", // Sweden
"CH", // Switzerland
"SD", // Sudan
"TW", // Taiwan
"TJ", // Tajikistan
"TZ", // Tanzania
"TH", // Thailand
"TL", // Timor-Leste (East Timor)
"TG", // Togo
"TO", // Tonga
"TT", // Trinidad and Tobago
"TN", // Tunisia
"TR", // Turkey
"TM", // Turkmenistan
"TV", // Tuvalu
"UG", // Uganda
"UA", // Ukraine (with certain exceptions)
"AE", // United Arab Emirates
"GB", // United Kingdom
"US", // United States of America
"UY", // Uruguay
"UZ", // Uzbekistan
"VU", // Vanuatu
"VN", // Vietnam
"YE", // Yemen
"ZM", // Zambia
"ZW", // Zimbabwe
]
.into_iter()
.collect()
});

View File

@ -6,6 +6,12 @@ extend-exclude = [
# File suffixes aren't typos # File suffixes aren't typos
"assets/icons/file_icons/file_types.json", "assets/icons/file_icons/file_types.json",
"crates/extensions_ui/src/extension_suggest.rs", "crates/extensions_ui/src/extension_suggest.rs",
# Some countries codes are flagged as typos.
"crates/anthropic/src/supported_countries.rs",
"crates/google_ai/src/supported_countries.rs",
"crates/open_ai/src/supported_countries.rs",
# Stripe IDs are flagged as typos. # Stripe IDs are flagged as typos.
"crates/collab/src/db/tests/processed_stripe_event_tests.rs", "crates/collab/src/db/tests/processed_stripe_event_tests.rs",
# Not our typos # Not our typos