mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
add should_truncate to embedding providers
This commit is contained in:
parent
e377ada1a9
commit
76caea80f7
@ -55,6 +55,7 @@ struct OpenAIEmbeddingUsage {
|
||||
pub trait EmbeddingProvider: Sync + Send {
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
|
||||
fn count_tokens(&self, span: &str) -> usize;
|
||||
fn should_truncate(&self, span: &str) -> bool;
|
||||
// fn truncate(&self, span: &str) -> Result<&str>;
|
||||
}
|
||||
|
||||
@ -74,6 +75,20 @@ impl EmbeddingProvider for DummyEmbeddings {
|
||||
let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
tokens.len()
|
||||
}
|
||||
|
||||
fn should_truncate(&self, span: &str) -> bool {
|
||||
self.count_tokens(span) > OPENAI_INPUT_LIMIT
|
||||
|
||||
// let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
// let Ok(output) = {
|
||||
// if tokens.len() > OPENAI_INPUT_LIMIT {
|
||||
// tokens.truncate(OPENAI_INPUT_LIMIT);
|
||||
// OPENAI_BPE_TOKENIZER.decode(tokens)
|
||||
// } else {
|
||||
// Ok(span)
|
||||
// }
|
||||
// };
|
||||
}
|
||||
}
|
||||
|
||||
const OPENAI_INPUT_LIMIT: usize = 8190;
|
||||
@ -125,6 +140,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
||||
tokens.len()
|
||||
}
|
||||
|
||||
fn should_truncate(&self, span: &str) -> bool {
|
||||
self.count_tokens(span) > OPENAI_INPUT_LIMIT
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
|
||||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||
const MAX_RETRIES: usize = 4;
|
||||
|
@ -1228,6 +1228,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
span.len()
|
||||
}
|
||||
|
||||
fn should_truncate(&self, span: &str) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
|
Loading…
Reference in New Issue
Block a user