mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-29 13:21:43 +03:00
move embedding truncation to base model
This commit is contained in:
parent
2b780ee7b2
commit
4e90e45999
@ -72,7 +72,6 @@ pub trait EmbeddingProvider: Sync + Send {
|
||||
fn is_authenticated(&self) -> bool;
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
||||
fn max_tokens_per_batch(&self) -> usize;
|
||||
fn truncate(&self, span: &str) -> (String, usize);
|
||||
fn rate_limit_expiration(&self) -> Option<Instant>;
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,10 @@ impl LanguageModel for DummyLanguageModel {
|
||||
length: usize,
|
||||
direction: crate::models::TruncationDirection,
|
||||
) -> anyhow::Result<String> {
|
||||
if content.len() < length {
|
||||
return anyhow::Ok(content.to_string());
|
||||
}
|
||||
|
||||
let truncated = match direction {
|
||||
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
||||
.iter()
|
||||
@ -73,11 +77,4 @@ impl EmbeddingProvider for DummyEmbeddingProvider {
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
8190
|
||||
}
|
||||
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
let truncated = span.chars().collect::<Vec<char>>()[..8190]
|
||||
.iter()
|
||||
.collect::<String>();
|
||||
(truncated, 8190)
|
||||
}
|
||||
}
|
||||
|
@ -61,8 +61,6 @@ struct OpenAIEmbeddingUsage {
|
||||
total_tokens: usize,
|
||||
}
|
||||
|
||||
const OPENAI_INPUT_LIMIT: usize = 8190;
|
||||
|
||||
impl OpenAIEmbeddingProvider {
|
||||
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
|
||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||
@ -151,20 +149,20 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
*self.rate_limit_count_rx.borrow()
|
||||
}
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
|
||||
tokens.truncate(OPENAI_INPUT_LIMIT);
|
||||
OPENAI_BPE_TOKENIZER
|
||||
.decode(tokens.clone())
|
||||
.ok()
|
||||
.unwrap_or_else(|| span.to_string())
|
||||
} else {
|
||||
span.to_string()
|
||||
};
|
||||
// fn truncate(&self, span: &str) -> (String, usize) {
|
||||
// let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
// let output = if tokens.len() > OPENAI_INPUT_LIMIT {
|
||||
// tokens.truncate(OPENAI_INPUT_LIMIT);
|
||||
// OPENAI_BPE_TOKENIZER
|
||||
// .decode(tokens.clone())
|
||||
// .ok()
|
||||
// .unwrap_or_else(|| span.to_string())
|
||||
// } else {
|
||||
// span.to_string()
|
||||
// };
|
||||
|
||||
(output, tokens.len())
|
||||
}
|
||||
// (output, tokens.len())
|
||||
// }
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||
|
@ -1,4 +1,7 @@
|
||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||
use ai::{
|
||||
embedding::{Embedding, EmbeddingProvider},
|
||||
models::TruncationDirection,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use language::{Grammar, Language};
|
||||
use rusqlite::{
|
||||
@ -108,7 +111,14 @@ impl CodeContextRetriever {
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_span = model.truncate(
|
||||
&document_span,
|
||||
model.capacity()?,
|
||||
ai::models::TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_span)?;
|
||||
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
@ -131,7 +141,15 @@ impl CodeContextRetriever {
|
||||
)
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_span = model.truncate(
|
||||
&document_span,
|
||||
model.capacity()?,
|
||||
ai::models::TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_span)?;
|
||||
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
@ -222,8 +240,13 @@ impl CodeContextRetriever {
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("item", &span.content);
|
||||
|
||||
let (document_content, token_count) =
|
||||
self.embedding_provider.truncate(&document_content);
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_content = model.truncate(
|
||||
&document_content,
|
||||
model.capacity()?,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_content)?;
|
||||
|
||||
span.content = document_content;
|
||||
span.token_count = token_count;
|
||||
|
@ -1291,12 +1291,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
(span.to_string(), 1)
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
200
|
||||
1000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
@ -1306,7 +1302,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
|
||||
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user