add should_truncate to embedding providers

This commit is contained in:
KCaverly 2023-08-30 11:58:45 -04:00
parent e377ada1a9
commit 76caea80f7
2 changed files with 23 additions and 0 deletions

View File

@ -55,6 +55,7 @@ struct OpenAIEmbeddingUsage {
pub trait EmbeddingProvider: Sync + Send {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
fn count_tokens(&self, span: &str) -> usize;
fn should_truncate(&self, span: &str) -> bool;
// fn truncate(&self, span: &str) -> Result<&str>;
}
@ -74,6 +75,20 @@ impl EmbeddingProvider for DummyEmbeddings {
let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
tokens.len()
}
fn should_truncate(&self, span: &str) -> bool {
self.count_tokens(span) > OPENAI_INPUT_LIMIT
// let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
// let Ok(output) = {
// if tokens.len() > OPENAI_INPUT_LIMIT {
// tokens.truncate(OPENAI_INPUT_LIMIT);
// OPENAI_BPE_TOKENIZER.decode(tokens)
// } else {
// Ok(span)
// }
// };
}
}
const OPENAI_INPUT_LIMIT: usize = 8190;
@ -125,6 +140,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
tokens.len()
}
fn should_truncate(&self, span: &str) -> bool {
self.count_tokens(span) > OPENAI_INPUT_LIMIT
}
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;

View File

@ -1228,6 +1228,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
span.len()
}
fn should_truncate(&self, span: &str) -> bool {
false
}
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);