Add ability to use o1-preview and o1-mini as custom models (#17804)

This is a barebones modification of the OpenAI provider code to
accommodate non-streaming completions. This is specifically for the o1
models, which do not support streaming. Tested that this is working by
running a `/workflow` with the following (arbitrarily chosen) settings:

```json
{
  "language_models": {
    "openai": {
      "version": "1",
      "available_models": [
        {
          "name": "o1-preview",
          "display_name": "o1-preview",
          "max_tokens": 128000,
          "max_completion_tokens": 30000
        },
        {
          "name": "o1-mini",
          "display_name": "o1-mini",
          "max_tokens": 128000,
          "max_completion_tokens": 20000
        }
      ]
    }
  },
}
```

Release Notes:

- Changed  `low_speed_timeout_in_seconds` option to `600` for OpenAI
provider to accommodate recent o1 model release.

---------

Co-authored-by: Peter <peter@zed.dev>
Co-authored-by: Bennet <bennet@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
jvmncs 2024-09-13 15:42:15 -04:00 committed by GitHub
parent 1b36c62188
commit c71f052276
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 136 additions and 4 deletions

View File

@ -916,7 +916,8 @@
},
"openai": {
"version": "1",
"api_url": "https://api.openai.com/v1"
"api_url": "https://api.openai.com/v1",
"low_speed_timeout_in_seconds": 600
}
},
// Zed's Prettier integration settings.

View File

@ -163,11 +163,13 @@ impl AssistantSettingsContent {
display_name,
max_tokens,
max_output_tokens,
max_completion_tokens: None,
} => Some(open_ai::AvailableModel {
name,
display_name,
max_tokens,
max_output_tokens,
max_completion_tokens: None,
}),
_ => None,
})

View File

@ -2407,7 +2407,7 @@ impl Codegen {
Ok(LanguageModelRequest {
messages,
tools: Vec::new(),
stop: vec!["|END|>".to_string()],
stop: Vec::new(),
temperature: 1.,
})
}

View File

@ -78,6 +78,8 @@ pub struct AvailableModel {
pub max_tokens: usize,
/// The maximum number of output tokens allowed by the model.
pub max_output_tokens: Option<u32>,
/// The maximum number of completion tokens allowed by the model (o1-* only)
pub max_completion_tokens: Option<u32>,
/// Override this model with a different Anthropic model for tool calls.
pub tool_override: Option<String>,
/// Indicates whether this custom model supports caching.
@ -257,6 +259,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
max_completion_tokens: model.max_completion_tokens,
}),
AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
name: model.name.clone(),

View File

@ -43,6 +43,7 @@ pub struct AvailableModel {
pub display_name: Option<String>,
pub max_tokens: usize,
pub max_output_tokens: Option<u32>,
pub max_completion_tokens: Option<u32>,
}
pub struct OpenAiLanguageModelProvider {
@ -175,6 +176,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
max_completion_tokens: model.max_completion_tokens,
},
);
}

View File

@ -178,11 +178,13 @@ impl OpenAiSettingsContent {
display_name,
max_tokens,
max_output_tokens,
max_completion_tokens,
} => Some(provider::open_ai::AvailableModel {
name,
max_tokens,
max_output_tokens,
display_name,
max_completion_tokens,
}),
_ => None,
})

View File

@ -1,12 +1,21 @@
mod supported_countries;
use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use futures::{
io::BufReader,
stream::{self, BoxStream},
AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
use std::{
convert::TryFrom,
future::{self, Future},
pin::Pin,
time::Duration,
};
use strum::EnumIter;
pub use supported_countries::*;
@ -72,6 +81,7 @@ pub enum Model {
display_name: Option<String>,
max_tokens: usize,
max_output_tokens: Option<u32>,
max_completion_tokens: Option<u32>,
},
}
@ -139,6 +149,7 @@ pub struct Request {
pub stream: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop: Vec<String>,
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
@ -263,6 +274,111 @@ pub struct ResponseStreamEvent {
pub usage: Option<Usage>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Usage,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
pub index: u32,
pub message: RequestMessage,
pub finish_reason: Option<String>,
}
pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<Response> {
let uri = format!("{api_url}/chat/completions");
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key));
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
};
let mut request_body = request;
request_body.stream = false;
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: Response = serde_json::from_str(&body)?;
Ok(response)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAiResponse {
error: OpenAiError,
}
#[derive(Deserialize)]
struct OpenAiError {
message: String,
}
match serde_json::from_str::<OpenAiResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
ResponseStreamEvent {
created: response.created as u32,
model: response.model,
choices: response
.choices
.into_iter()
.map(|choice| ChoiceDelta {
index: choice.index,
delta: ResponseMessageDelta {
role: Some(match choice.message {
RequestMessage::Assistant { .. } => Role::Assistant,
RequestMessage::User { .. } => Role::User,
RequestMessage::System { .. } => Role::System,
RequestMessage::Tool { .. } => Role::Tool,
}),
content: match choice.message {
RequestMessage::Assistant { content, .. } => content,
RequestMessage::User { content } => Some(content),
RequestMessage::System { content } => Some(content),
RequestMessage::Tool { content, .. } => Some(content),
},
tool_calls: None,
},
finish_reason: choice.finish_reason,
})
.collect(),
usage: Some(response.usage),
}
}
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
@ -270,6 +386,12 @@ pub async fn stream_completion(
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
if request.model == "o1-preview" || request.model == "o1-mini" {
let response = complete(client, api_url, api_key, request, low_speed_timeout).await;
let response_stream_event = response.map(adapt_response_to_stream);
return Ok(stream::once(future::ready(response_stream_event)).boxed());
}
let uri = format!("{api_url}/chat/completions");
let mut request_builder = HttpRequest::builder()
.method(Method::POST)