collab: Remove LLM completions over RPC (#16114)

This PR removes the LLM completion messages from the RPC protocol, as
these now go through the LLM service as of #16113.

Release Notes:

- N/A
Marshall Bowers 2024-08-12 10:08:56 -04:00 committed by GitHub
parent f992cfdc7f
commit f952126319
3 changed files with 1 addition and 299 deletions
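
For orientation, here is a minimal, self-contained sketch of the streaming-request pattern being deleted in the diff below. The channel-based types, StreamMessage, and handle_stream_request are illustrative stand-ins, not Zed's actual Peer/Receipt machinery: a streaming handler pushes any number of response chunks back to the peer, and the server ends the stream once the handler returns.

// Illustrative stand-ins only (not Zed's Peer/Receipt types): a streaming
// handler may call `send` any number of times; the server then ends the
// stream when the handler returns.
use std::sync::mpsc;

// Stand-in for proto::StreamCompleteWithLanguageModelResponse.
#[derive(Debug)]
enum StreamMessage {
    Chunk(String),
    EndOfStream,
}

// Stand-in for StreamingResponse<R>: each `send` forwards one chunk to the peer.
struct StreamingResponse {
    tx: mpsc::Sender<StreamMessage>,
}

impl StreamingResponse {
    fn send(&self, event: String) -> Result<(), mpsc::SendError<StreamMessage>> {
        self.tx.send(StreamMessage::Chunk(event))
    }
}

// Stand-in handler, mirroring how stream_complete_with_language_model
// forwarded each upstream completion event as it arrived.
fn handle_stream_request(upstream: &[&str], response: &StreamingResponse) {
    for event in upstream {
        response.send(event.to_string()).expect("peer disconnected");
    }
}

fn main() {
    let (tx, rx) = mpsc::channel();
    let response = StreamingResponse { tx: tx.clone() };

    handle_stream_request(&["chunk 1", "chunk 2", "chunk 3"], &response);
    // Equivalent of peer.end_stream(receipt) after the handler returns Ok(()).
    tx.send(StreamMessage::EndOfStream).unwrap();
    drop((tx, response));

    for message in rx {
        println!("{message:?}");
    }
}

In the removed server code, StreamingResponse::send forwarded each serialized provider event over RPC, and add_streaming_request_handler called peer.end_stream(receipt) once the handler finished successfully.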


@@ -105,18 +105,6 @@ impl<R: RequestMessage> Response<R> {
    }
}

struct StreamingResponse<R: RequestMessage> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
}

impl<R: RequestMessage> StreamingResponse<R> {
    fn send(&self, payload: R::Response) -> Result<()> {
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}

#[derive(Clone, Debug)]
pub enum Principal {
    User(User),
@@ -630,31 +618,6 @@ impl Server {
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        complete_with_language_model(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        stream_complete_with_language_model(
                            request,
                            response,
                            session,
                            &app_state.config,
                        )
                        .await
                    }
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
@@ -948,40 +911,6 @@ impl Server {
        })
    }

    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let response = StreamingResponse {
                    peer: peer.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        peer.end_stream(receipt)?;
                        Ok(())
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
@@ -4561,202 +4490,6 @@ async fn acknowledge_buffer_version(
    Ok(())
}

struct ZedProCompleteWithLanguageModelRateLimit;

impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
    fn capacity(&self) -> usize {
        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(120) // Picked arbitrarily
    }

    fn refill_duration(&self) -> chrono::Duration {
        chrono::Duration::hours(1)
    }

    fn db_name(&self) -> &'static str {
        "zed-pro:complete-with-language-model"
    }
}

struct FreeCompleteWithLanguageModelRateLimit;

impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
    fn capacity(&self) -> usize {
        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(120 / 10) // Picked arbitrarily
    }

    fn refill_duration(&self) -> chrono::Duration {
        chrono::Duration::hours(1)
    }

    fn db_name(&self) -> &'static str {
        "free:complete-with-language-model"
    }
}

async fn complete_with_language_model(
    request: proto::CompleteWithLanguageModel,
    response: Response<proto::CompleteWithLanguageModel>,
    session: Session,
    config: &Config,
) -> Result<()> {
    let Some(session) = session.for_user() else {
        return Err(anyhow!("user not found"))?;
    };
    authorize_access_to_language_models(&session).await?;
    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
    };
    session
        .app_state
        .rate_limiter
        .check(&*rate_limit, session.user_id())
        .await?;
    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
        Some(proto::LanguageModelProvider::Anthropic) => {
            let api_key = config
                .anthropic_api_key
                .as_ref()
                .context("no Anthropic AI API key configured on the server")?;
            anthropic::complete(
                session.http_client.as_ref(),
                anthropic::ANTHROPIC_API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
            )
            .await?
        }
        _ => return Err(anyhow!("unsupported provider"))?,
    };
    response.send(proto::CompleteWithLanguageModelResponse {
        completion: serde_json::to_string(&result)?,
    })?;
    Ok(())
}

async fn stream_complete_with_language_model(
    request: proto::StreamCompleteWithLanguageModel,
    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
    session: Session,
    config: &Config,
) -> Result<()> {
    let Some(session) = session.for_user() else {
        return Err(anyhow!("user not found"))?;
    };
    authorize_access_to_language_models(&session).await?;
    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
    };
    session
        .app_state
        .rate_limiter
        .check(&*rate_limit, session.user_id())
        .await?;
    match proto::LanguageModelProvider::from_i32(request.provider) {
        Some(proto::LanguageModelProvider::Anthropic) => {
            let api_key = config
                .anthropic_api_key
                .as_ref()
                .context("no Anthropic AI API key configured on the server")?;
            let mut chunks = anthropic::stream_completion(
                session.http_client.as_ref(),
                anthropic::ANTHROPIC_API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
                None,
            )
            .await?;
            while let Some(event) = chunks.next().await {
                let chunk = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&chunk)?,
                })?;
            }
        }
        Some(proto::LanguageModelProvider::OpenAi) => {
            let api_key = config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let mut events = open_ai::stream_completion(
                session.http_client.as_ref(),
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
                None,
            )
            .await?;
            while let Some(event) = events.next().await {
                let event = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&event)?,
                })?;
            }
        }
        Some(proto::LanguageModelProvider::Google) => {
            let api_key = config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let mut events = google_ai::stream_generate_content(
                session.http_client.as_ref(),
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
            )
            .await?;
            while let Some(event) = events.next().await {
                let event = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&event)?,
                })?;
            }
        }
        Some(proto::LanguageModelProvider::Zed) => {
            let api_key = config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let mut events = open_ai::stream_completion(
                session.http_client.as_ref(),
                &api_url,
                api_key,
                serde_json::from_str(&request.request)?,
                None,
            )
            .await?;
            while let Some(event) = events.next().await {
                let event = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&event)?,
                })?;
            }
        }
        None => return Err(anyhow!("unknown provider"))?,
    }
    Ok(())
}

async fn count_language_model_tokens(
    request: proto::CountLanguageModelTokens,
    response: Response<proto::CountLanguageModelTokens>,


@@ -197,10 +197,6 @@ message Envelope {
        JoinHostedProject join_hosted_project = 164;
        CompleteWithLanguageModel complete_with_language_model = 226;
        CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
        StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
        StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
        CountLanguageModelTokens count_language_model_tokens = 230;
        CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
        GetCachedEmbeddings get_cached_embeddings = 189;
@@ -279,7 +275,7 @@ message Envelope {
    reserved 158 to 161;
    reserved 166 to 169;
    reserved 224 to 225;
    reserved 224 to 229;
}

// Messages
@@ -2084,24 +2080,6 @@ enum LanguageModelRole {
    reserved 3;
}

message CompleteWithLanguageModel {
    LanguageModelProvider provider = 1;
    string request = 2;
}

message CompleteWithLanguageModelResponse {
    string completion = 1;
}

message StreamCompleteWithLanguageModel {
    LanguageModelProvider provider = 1;
    string request = 2;
}

message StreamCompleteWithLanguageModelResponse {
    string event = 1;
}

message CountLanguageModelTokens {
    LanguageModelProvider provider = 1;
    string request = 2;


@@ -298,10 +298,6 @@ messages!(
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (CompleteWithLanguageModel, Background),
    (CompleteWithLanguageModelResponse, Background),
    (StreamCompleteWithLanguageModel, Background),
    (StreamCompleteWithLanguageModelResponse, Background),
    (CountLanguageModelTokens, Background),
    (CountLanguageModelTokensResponse, Background),
    (RefreshInlayHints, Foreground),
@@ -476,11 +472,6 @@ request_messages!(
    (PerformRename, PerformRenameResponse),
    (Ping, Ack),
    (PrepareRename, PrepareRenameResponse),
    (CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
    (
        StreamCompleteWithLanguageModel,
        StreamCompleteWithLanguageModelResponse
    ),
    (CountLanguageModelTokens, CountLanguageModelTokensResponse),
    (RefreshInlayHints, Ack),
    (RejoinChannelBuffers, RejoinChannelBuffersResponse),