collab: Remove LLM completions over RPC (#16114)
This PR removes the LLM completion messages from the RPC protocol, as these now go through the LLM service as of #16113.

Release Notes:

- N/A
This commit is contained in:
parent
f992cfdc7f
commit
f952126319
@@ -105,18 +105,6 @@ impl<R: RequestMessage> Response<R> {
     }
 }
 
-struct StreamingResponse<R: RequestMessage> {
-    peer: Arc<Peer>,
-    receipt: Receipt<R>,
-}
-
-impl<R: RequestMessage> StreamingResponse<R> {
-    fn send(&self, payload: R::Response) -> Result<()> {
-        self.peer.respond(self.receipt, payload)?;
-        Ok(())
-    }
-}
-
 #[derive(Clone, Debug)]
 pub enum Principal {
     User(User),
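Aside: the removed StreamingResponse is a thin handle pairing a peer with the receipt of the originating request, so a handler can emit any number of payloads in answer to a single request. Below is a minimal stand-alone sketch of the same shape; FakePeer, the u32 receipt, and the String payload are simplified stand-ins, not collab's real Peer/Receipt/RequestMessage types.

use std::sync::mpsc;

// Stand-in for the peer connection: each response is tagged with the
// receipt (message id) of the request it answers.
struct FakePeer {
    tx: mpsc::Sender<(u32, String)>,
}

impl FakePeer {
    fn respond(&self, receipt: u32, payload: String) -> Result<(), mpsc::SendError<(u32, String)>> {
        self.tx.send((receipt, payload))
    }
}

// The pattern from the removed code: hold the peer plus the receipt so a
// handler can send many payloads for one request.
struct StreamingResponse {
    peer: FakePeer,
    receipt: u32,
}

impl StreamingResponse {
    fn send(&self, payload: String) -> Result<(), mpsc::SendError<(u32, String)>> {
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}

fn main() {
    let (tx, rx) = mpsc::channel();
    let response = StreamingResponse { peer: FakePeer { tx }, receipt: 42 };
    for chunk in ["hel", "lo"] {
        response.send(chunk.to_string()).unwrap();
    }
    drop(response); // closing the channel plays the role of the real end_stream
    for (receipt, payload) in rx {
        println!("receipt {receipt}: {payload}");
    }
}

Dropping the handle stands in for the explicit end_stream call the real server makes once a handler finishes (see the wrapper removed further down).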
@@ -630,31 +618,6 @@ impl Server {
             ))
             .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
             .add_message_handler(update_context)
-            .add_request_handler({
-                let app_state = app_state.clone();
-                move |request, response, session| {
-                    let app_state = app_state.clone();
-                    async move {
-                        complete_with_language_model(request, response, session, &app_state.config)
-                            .await
-                    }
-                }
-            })
-            .add_streaming_request_handler({
-                let app_state = app_state.clone();
-                move |request, response, session| {
-                    let app_state = app_state.clone();
-                    async move {
-                        stream_complete_with_language_model(
-                            request,
-                            response,
-                            session,
-                            &app_state.config,
-                        )
-                        .await
-                    }
-                }
-            })
             .add_request_handler({
                 let app_state = app_state.clone();
                 move |request, response, session| {
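One detail worth noting in the removed registrations: app_state is cloned twice, once into the move closure and once more inside it. The closure is invoked once per incoming request and must stay callable, so it keeps its own Arc and hands each invocation's async block a fresh clone. A minimal sketch of that pattern; the AppState shape and the handler signature here are illustrative, not collab's real ones.

use std::future::Future;
use std::sync::Arc;

struct AppState {
    config: String,
}

// Accepts a factory that is called once per request; each call returns a future.
fn add_request_handler<F, Fut>(handler: F)
where
    F: Fn(u32) -> Fut,
    Fut: Future<Output = ()>,
{
    // Drive two "requests" to show the closure is reused.
    futures::executor::block_on(handler(1));
    futures::executor::block_on(handler(2));
}

fn main() {
    let app_state = Arc::new(AppState { config: "config".into() });

    add_request_handler({
        // First clone: moved into the closure, which outlives this scope.
        let app_state = app_state.clone();
        move |request| {
            // Second clone: each async block takes its own owned handle,
            // because the closure must remain callable for later requests.
            let app_state = app_state.clone();
            async move {
                println!("request {request} handled with {}", app_state.config);
            }
        }
    });
}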
@@ -948,40 +911,6 @@ impl Server {
         })
     }
 
-    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
-    where
-        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
-        Fut: Send + Future<Output = Result<()>>,
-        M: RequestMessage,
-    {
-        let handler = Arc::new(handler);
-        self.add_handler(move |envelope, session| {
-            let receipt = envelope.receipt();
-            let handler = handler.clone();
-            async move {
-                let peer = session.peer.clone();
-                let response = StreamingResponse {
-                    peer: peer.clone(),
-                    receipt,
-                };
-                match (handler)(envelope.payload, response, session).await {
-                    Ok(()) => {
-                        peer.end_stream(receipt)?;
-                        Ok(())
-                    }
-                    Err(error) => {
-                        let proto_err = match &error {
-                            Error::Internal(err) => err.to_proto(),
-                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
-                        };
-                        peer.respond_with_error(receipt, proto_err)?;
-                        Err(error)
-                    }
-                }
-            }
-        })
-    }
-
     #[allow(clippy::too_many_arguments)]
     pub fn handle_connection(
         self: &Arc<Self>,
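The generic wrapper removed above encodes the streaming protocol's invariant: every request terminates in exactly one of two ways, an explicit end-of-stream marker after the handler succeeds, or a single error frame if it fails (the error is still propagated so the server can log it). A stand-alone sketch of that terminal step; the Frame enum and String errors are simplified stand-ins for the real proto error types.

#[derive(Debug)]
enum Frame {
    Payload(String),
    EndStream,
    Error(String),
}

// Mirrors the match in the removed add_streaming_request_handler: run the
// handler, then always emit exactly one terminal frame for the request.
fn finish_stream(out: &mut Vec<Frame>, handler_result: Result<(), String>) -> Result<(), String> {
    match handler_result {
        Ok(()) => {
            out.push(Frame::EndStream);
            Ok(())
        }
        Err(error) => {
            // The real code converts the error to a proto error before
            // responding, then still returns Err for logging upstream.
            out.push(Frame::Error(error.clone()));
            Err(error)
        }
    }
}

fn main() {
    let mut out = vec![Frame::Payload("chunk".into())];
    finish_stream(&mut out, Ok(())).unwrap();
    println!("{out:?}"); // [Payload("chunk"), EndStream]

    let mut failed = Vec::new();
    let _ = finish_stream(&mut failed, Err("boom".into()));
    println!("{failed:?}"); // [Error("boom")]
}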
@@ -4561,202 +4490,6 @@ async fn acknowledge_buffer_version(
     Ok(())
 }
 
-struct ZedProCompleteWithLanguageModelRateLimit;
-
-impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(120) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "zed-pro:complete-with-language-model"
-    }
-}
-
-struct FreeCompleteWithLanguageModelRateLimit;
-
-impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(120 / 10) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "free:complete-with-language-model"
-    }
-}
-
-async fn complete_with_language_model(
-    request: proto::CompleteWithLanguageModel,
-    response: Response<proto::CompleteWithLanguageModel>,
-    session: Session,
-    config: &Config,
-) -> Result<()> {
-    let Some(session) = session.for_user() else {
-        return Err(anyhow!("user not found"))?;
-    };
-    authorize_access_to_language_models(&session).await?;
-
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
-        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
-        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
-    };
-
-    session
-        .app_state
-        .rate_limiter
-        .check(&*rate_limit, session.user_id())
-        .await?;
-
-    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
-        Some(proto::LanguageModelProvider::Anthropic) => {
-            let api_key = config
-                .anthropic_api_key
-                .as_ref()
-                .context("no Anthropic AI API key configured on the server")?;
-            anthropic::complete(
-                session.http_client.as_ref(),
-                anthropic::ANTHROPIC_API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-            )
-            .await?
-        }
-        _ => return Err(anyhow!("unsupported provider"))?,
-    };
-
-    response.send(proto::CompleteWithLanguageModelResponse {
-        completion: serde_json::to_string(&result)?,
-    })?;
-
-    Ok(())
-}
-
-async fn stream_complete_with_language_model(
-    request: proto::StreamCompleteWithLanguageModel,
-    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
-    session: Session,
-    config: &Config,
-) -> Result<()> {
-    let Some(session) = session.for_user() else {
-        return Err(anyhow!("user not found"))?;
-    };
-    authorize_access_to_language_models(&session).await?;
-
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
-        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
-        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
-    };
-
-    session
-        .app_state
-        .rate_limiter
-        .check(&*rate_limit, session.user_id())
-        .await?;
-
-    match proto::LanguageModelProvider::from_i32(request.provider) {
-        Some(proto::LanguageModelProvider::Anthropic) => {
-            let api_key = config
-                .anthropic_api_key
-                .as_ref()
-                .context("no Anthropic AI API key configured on the server")?;
-            let mut chunks = anthropic::stream_completion(
-                session.http_client.as_ref(),
-                anthropic::ANTHROPIC_API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-                None,
-            )
-            .await?;
-            while let Some(event) = chunks.next().await {
-                let chunk = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&chunk)?,
-                })?;
-            }
-        }
-        Some(proto::LanguageModelProvider::OpenAi) => {
-            let api_key = config
-                .openai_api_key
-                .as_ref()
-                .context("no OpenAI API key configured on the server")?;
-            let mut events = open_ai::stream_completion(
-                session.http_client.as_ref(),
-                open_ai::OPEN_AI_API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-                None,
-            )
-            .await?;
-            while let Some(event) = events.next().await {
-                let event = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&event)?,
-                })?;
-            }
-        }
-        Some(proto::LanguageModelProvider::Google) => {
-            let api_key = config
-                .google_ai_api_key
-                .as_ref()
-                .context("no Google AI API key configured on the server")?;
-            let mut events = google_ai::stream_generate_content(
-                session.http_client.as_ref(),
-                google_ai::API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-            )
-            .await?;
-            while let Some(event) = events.next().await {
-                let event = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&event)?,
-                })?;
-            }
-        }
-        Some(proto::LanguageModelProvider::Zed) => {
-            let api_key = config
-                .qwen2_7b_api_key
-                .as_ref()
-                .context("no Qwen2-7B API key configured on the server")?;
-            let api_url = config
-                .qwen2_7b_api_url
-                .as_ref()
-                .context("no Qwen2-7B URL configured on the server")?;
-            let mut events = open_ai::stream_completion(
-                session.http_client.as_ref(),
-                &api_url,
-                api_key,
-                serde_json::from_str(&request.request)?,
-                None,
-            )
-            .await?;
-            while let Some(event) = events.next().await {
-                let event = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&event)?,
-                })?;
-            }
-        }
-        None => return Err(anyhow!("unknown provider"))?,
-    }
-
-    Ok(())
-}
-
 async fn count_language_model_tokens(
     request: proto::CountLanguageModelTokens,
     response: Response<proto::CountLanguageModelTokens>,
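The removed rate limits are plan-dependent: by default 120 requests per hour for Zed Pro and 120 / 10 = 12 per hour for Free, both refilling over one hour. The real RateLimit implementations carry a db_name, which suggests the buckets are persisted; the in-memory token-bucket sketch below only illustrates the arithmetic under those default parameters, it is not collab's RateLimiter.

use std::time::{Duration, Instant};

// In-memory token bucket: capacity tokens refill evenly over refill_duration,
// e.g. 120 tokens per hour means one token back every 30 seconds.
struct TokenBucket {
    capacity: f64,
    refill_duration: Duration,
    tokens: f64,
    last_check: Instant,
}

impl TokenBucket {
    fn new(capacity: usize, refill_duration: Duration) -> Self {
        Self {
            capacity: capacity as f64,
            refill_duration,
            tokens: capacity as f64,
            last_check: Instant::now(),
        }
    }

    fn check(&mut self) -> Result<(), String> {
        // Refill proportionally to the time elapsed since the last check.
        let elapsed = self.last_check.elapsed();
        self.last_check = Instant::now();
        let refill = self.capacity * elapsed.as_secs_f64() / self.refill_duration.as_secs_f64();
        self.tokens = (self.tokens + refill).min(self.capacity);
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            Ok(())
        } else {
            Err("rate limit exceeded".into())
        }
    }
}

fn main() {
    // Default Free-plan capacity from the removed code: 120 / 10 = 12 per hour.
    let mut bucket = TokenBucket::new(12, Duration::from_secs(3600));
    for i in 0..13 {
        println!("request {i}: {:?}", bucket.check());
    }
    // The 13th request fails: a full bucket holds 12 tokens and a meaningful
    // fraction of the hour-long refill has not elapsed.
}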
@@ -197,10 +197,6 @@ message Envelope {
 
         JoinHostedProject join_hosted_project = 164;
 
-        CompleteWithLanguageModel complete_with_language_model = 226;
-        CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
-        StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
-        StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
         CountLanguageModelTokens count_language_model_tokens = 230;
         CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
         GetCachedEmbeddings get_cached_embeddings = 189;
@@ -279,7 +275,7 @@ message Envelope {
 
     reserved 158 to 161;
     reserved 166 to 169;
-    reserved 224 to 225;
+    reserved 224 to 229;
 }
 
 // Messages

(Extending the reserved range from 224 to 225 up to 224 to 229 retires tags 226 through 229, the field numbers freed by the removed completion messages, so no future Envelope field can ever reuse them on the wire.)
@@ -2084,24 +2080,6 @@ enum LanguageModelRole {
     reserved 3;
 }
 
-message CompleteWithLanguageModel {
-    LanguageModelProvider provider = 1;
-    string request = 2;
-}
-
-message CompleteWithLanguageModelResponse {
-    string completion = 1;
-}
-
-message StreamCompleteWithLanguageModel {
-    LanguageModelProvider provider = 1;
-    string request = 2;
-}
-
-message StreamCompleteWithLanguageModelResponse {
-    string event = 1;
-}
-
 message CountLanguageModelTokens {
     LanguageModelProvider provider = 1;
     string request = 2;
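These removed messages show the envelope pattern the old RPC path used: a LanguageModelProvider tag plus the provider's native request serialized to a JSON string, so the protocol never had to model each provider's schema. A sketch of that round-trip; FakeAnthropicRequest is a made-up stand-in for a provider request struct, and the provider tag value is illustrative.

use serde::{Deserialize, Serialize};

// Stand-in for a provider-specific request body; the real types live in the
// anthropic, open_ai, and google_ai crates.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct FakeAnthropicRequest {
    model: String,
    max_tokens: u32,
}

// Mirrors the removed CompleteWithLanguageModel message: a provider enum tag
// plus an opaque JSON payload.
struct CompleteWithLanguageModel {
    provider: i32,
    request: String,
}

fn main() -> Result<(), serde_json::Error> {
    let original = FakeAnthropicRequest {
        model: "some-model".into(),
        max_tokens: 1024,
    };

    // Client side: serialize the provider request into the envelope's string field.
    let envelope = CompleteWithLanguageModel {
        provider: 1, // illustrative tag for LanguageModelProvider::Anthropic
        request: serde_json::to_string(&original)?,
    };

    // Server side: dispatch on provider, then decode the payload, as the
    // removed handlers did with serde_json::from_str(&request.request)?.
    let decoded: FakeAnthropicRequest = serde_json::from_str(&envelope.request)?;
    assert_eq!(decoded, original);
    println!("provider {} request: {:?}", envelope.provider, decoded);
    Ok(())
}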
@@ -298,10 +298,6 @@ messages!(
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
-   (CompleteWithLanguageModel, Background),
-   (CompleteWithLanguageModelResponse, Background),
-   (StreamCompleteWithLanguageModel, Background),
-   (StreamCompleteWithLanguageModelResponse, Background),
    (CountLanguageModelTokens, Background),
    (CountLanguageModelTokensResponse, Background),
    (RefreshInlayHints, Foreground),
@@ -476,11 +472,6 @@ request_messages!(
    (PerformRename, PerformRenameResponse),
    (Ping, Ack),
    (PrepareRename, PrepareRenameResponse),
-   (CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
-   (
-       StreamCompleteWithLanguageModel,
-       StreamCompleteWithLanguageModelResponse
-   ),
    (CountLanguageModelTokens, CountLanguageModelTokensResponse),
    (RefreshInlayHints, Ack),
    (RejoinChannelBuffers, RejoinChannelBuffersResponse),
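The request_messages! table pairs each request type with its response type; the R::Response bound on the removed StreamingResponse<R> relies on exactly this association. Below is a hypothetical simplification of what such a pairing provides, not the macro's real expansion, which is more involved.

// Each request names its response type at compile time, which is what lets
// generic code like StreamingResponse<R> accept only R::Response payloads.
trait RequestMessage {
    type Response;
}

struct Ping;
struct Ack;

struct CountLanguageModelTokens;
struct CountLanguageModelTokensResponse;

// Roughly what one (Request, Response) row in the macro table declares.
impl RequestMessage for Ping {
    type Response = Ack;
}

impl RequestMessage for CountLanguageModelTokens {
    type Response = CountLanguageModelTokensResponse;
}

// A responder that can only be answered with the matching response type.
fn respond<R: RequestMessage>(_request: R, _response: R::Response) {}

fn main() {
    respond(Ping, Ack); // compiles
    respond(CountLanguageModelTokens, CountLanguageModelTokensResponse);
    // respond(Ping, CountLanguageModelTokensResponse); // would not compile
}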