corrected batching order and managed for open ai embedding errors

This commit is contained in:
KCaverly 2023-07-06 11:11:39 -04:00
parent afccf608f4
commit a86b6c42c7
3 changed files with 169 additions and 173 deletions

View File

@ -32,6 +32,7 @@ async-trait.workspace = true
bincode = "1.3.3"
matrixmultiply = "0.3.7"
tiktoken-rs = "0.5.0"
rand.workspace = true
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }

View File

@ -2,15 +2,20 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use std::{env, time::Instant};
use std::time::Duration;
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
#[derive(Clone)]
@ -60,56 +65,80 @@ impl EmbeddingProvider for DummyEmbeddings {
}
}
// impl OpenAIEmbeddings {
// async fn truncate(span: &str) -> String {
// let bpe = cl100k_base().unwrap();
// let mut tokens = bpe.encode_with_special_tokens(span);
// if tokens.len() > 8192 {
// tokens.truncate(8192);
// let result = bpe.decode(tokens);
// if result.is_ok() {
// return result.unwrap();
// }
// }
impl OpenAIEmbeddings {
async fn truncate(span: String) -> String {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
if tokens.len() > 8190 {
tokens.truncate(8190);
let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
if result.is_ok() {
let transformed = result.unwrap();
// assert_ne!(transformed, span);
return transformed;
}
}
// return span.to_string();
// }
// }
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
// Truncate spans to 8192 if needed
// let t0 = Instant::now();
// let mut truncated_spans = vec![];
// for span in spans {
// truncated_spans.push(Self::truncate(span));
// }
// let spans = futures::future::join_all(truncated_spans).await;
// log::info!("Truncated Spans in {:?}", t0.elapsed().as_secs());
let api_key = OPENAI_API_KEY
.as_ref()
.ok_or_else(|| anyhow!("no api key"))?;
return span.to_string();
}
async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAIEmbeddingRequest {
input: spans,
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
let mut response = self.client.send(request).await?;
if !response.status().is_success() {
return Err(anyhow!("openai embedding failed {}", response.status()));
Ok(self.client.send(request).await?)
}
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360];
const MAX_RETRIES: usize = 3;
let api_key = OPENAI_API_KEY
.as_ref()
.ok_or_else(|| anyhow!("no api key"))?;
let mut request_number = 0;
let mut response: Response<AsyncBody>;
let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
while request_number < MAX_RETRIES {
response = self
.send_request(api_key, spans.iter().map(|x| &**x).collect())
.await?;
request_number += 1;
if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
return Err(anyhow!(
"openai max retries, error: {:?}",
&response.status()
));
}
match response.status() {
StatusCode::TOO_MANY_REQUESTS => {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
std::thread::sleep(delay);
}
StatusCode::BAD_REQUEST => {
log::info!("BAD REQUEST: {:?}", &response.status());
// Don't worry about delaying bad request, as we can assume
// we haven't been rate limited yet.
for span in spans.iter_mut() {
*span = Self::truncate(span.to_string()).await;
}
}
StatusCode::OK => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
@ -118,11 +147,18 @@ impl EmbeddingProvider for OpenAIEmbeddings {
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
Ok(response
return Ok(response
.data
.into_iter()
.map(|embedding| embedding.embedding)
.collect())
.collect());
}
_ => {
return Err(anyhow!("openai embedding failed {}", response.status()));
}
}
}
Err(anyhow!("openai embedding failed"))
}
}

View File

@ -74,7 +74,6 @@ pub fn init(
cx.subscribe_global::<WorkspaceCreated, _>({
let vector_store = vector_store.clone();
move |event, cx| {
let t0 = Instant::now();
let workspace = &event.0;
if let Some(workspace) = workspace.upgrade(cx) {
let project = workspace.read(cx).project().clone();
@ -126,9 +125,7 @@ pub struct VectorStore {
language_registry: Arc<LanguageRegistry>,
db_update_tx: channel::Sender<DbWrite>,
// embed_batch_tx: channel::Sender<Vec<(i64, IndexedFile, Vec<String>)>>,
batch_files_tx: channel::Sender<(i64, IndexedFile, Vec<String>)>,
parsing_files_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
parsing_files_rx: channel::Receiver<(i64, PathBuf, Arc<Language>, SystemTime)>,
_db_update_task: Task<()>,
_embed_batch_task: Vec<Task<()>>,
_batch_files_task: Task<()>,
@ -220,14 +217,13 @@ impl VectorStore {
let (embed_batch_tx, embed_batch_rx) =
channel::unbounded::<Vec<(i64, IndexedFile, Vec<String>)>>();
let mut _embed_batch_task = Vec::new();
for _ in 0..cx.background().num_cpus() {
for _ in 0..1 {
//cx.background().num_cpus() {
let db_update_tx = db_update_tx.clone();
let embed_batch_rx = embed_batch_rx.clone();
let embedding_provider = embedding_provider.clone();
_embed_batch_task.push(cx.background().spawn(async move {
while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
log::info!("Embedding Batch! ");
// Construct Batch
let mut embeddings_queue = embeddings_queue.clone();
let mut document_spans = vec![];
@ -235,20 +231,20 @@ impl VectorStore {
document_spans.extend(document_span);
}
if let Some(mut embeddings) = embedding_provider
if let Ok(embeddings) = embedding_provider
.embed_batch(document_spans.iter().map(|x| &**x).collect())
.await
.log_err()
{
let mut i = 0;
let mut j = 0;
while let Some(embedding) = embeddings.pop() {
for embedding in embeddings.iter() {
while embeddings_queue[i].1.documents.len() == j {
i += 1;
j = 0;
}
embeddings_queue[i].1.documents[j].embedding = embedding;
embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
j += 1;
}
@ -283,7 +279,6 @@ impl VectorStore {
while let Ok((worktree_id, indexed_file, document_spans)) =
batch_files_rx.recv().await
{
log::info!("Batching File: {:?}", &indexed_file.path);
queue_len += &document_spans.len();
embeddings_queue.push((worktree_id, indexed_file, document_spans));
if queue_len >= EMBEDDINGS_BATCH_SIZE {
@ -338,10 +333,7 @@ impl VectorStore {
embedding_provider,
language_registry,
db_update_tx,
// embed_batch_tx,
batch_files_tx,
parsing_files_tx,
parsing_files_rx,
_db_update_task,
_embed_batch_task,
_batch_files_task,
@ -449,8 +441,6 @@ impl VectorStore {
let database_url = self.database_url.clone();
let db_update_tx = self.db_update_tx.clone();
let parsing_files_tx = self.parsing_files_tx.clone();
let parsing_files_rx = self.parsing_files_rx.clone();
let batch_files_tx = self.batch_files_tx.clone();
cx.spawn(|this, mut cx| async move {
let t0 = Instant::now();
@ -553,37 +543,6 @@ impl VectorStore {
})
.detach();
// cx.background()
// .scoped(|scope| {
// for _ in 0..cx.background().num_cpus() {
// scope.spawn(async {
// let mut parser = Parser::new();
// let mut cursor = QueryCursor::new();
// while let Ok((worktree_id, file_path, language, mtime)) =
// parsing_files_rx.recv().await
// {
// log::info!("Parsing File: {:?}", &file_path);
// if let Some((indexed_file, document_spans)) = Self::index_file(
// &mut cursor,
// &mut parser,
// &fs,
// language,
// file_path.clone(),
// mtime,
// )
// .await
// .log_err()
// {
// batch_files_tx
// .try_send((worktree_id, indexed_file, document_spans))
// .unwrap();
// }
// }
// });
// }
// })
// .await;
this.update(&mut cx, |this, cx| {
// The below is managing for updated on save
// Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
@ -592,90 +551,90 @@ impl VectorStore {
if let Some(project_state) = this.projects.get(&project.downgrade()) {
let worktree_db_ids = project_state.worktree_db_ids.clone();
// if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
// {
// // Iterate through changes
// let language_registry = this.language_registry.clone();
if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
{
// Iterate through changes
let language_registry = this.language_registry.clone();
// let db =
// VectorDatabase::new(this.database_url.to_string_lossy().into());
// if db.is_err() {
// return;
// }
// let db = db.unwrap();
let db =
VectorDatabase::new(this.database_url.to_string_lossy().into());
if db.is_err() {
return;
}
let db = db.unwrap();
// let worktree_db_id: Option<i64> = {
// let mut found_db_id = None;
// for (w_id, db_id) in worktree_db_ids.into_iter() {
// if &w_id == worktree_id {
// found_db_id = Some(db_id);
// }
// }
let worktree_db_id: Option<i64> = {
let mut found_db_id = None;
for (w_id, db_id) in worktree_db_ids.into_iter() {
if &w_id == worktree_id {
found_db_id = Some(db_id);
}
}
// found_db_id
// };
found_db_id
};
// if worktree_db_id.is_none() {
// return;
// }
// let worktree_db_id = worktree_db_id.unwrap();
if worktree_db_id.is_none() {
return;
}
let worktree_db_id = worktree_db_id.unwrap();
// let file_mtimes = db.get_file_mtimes(worktree_db_id);
// if file_mtimes.is_err() {
// return;
// }
let file_mtimes = db.get_file_mtimes(worktree_db_id);
if file_mtimes.is_err() {
return;
}
// let file_mtimes = file_mtimes.unwrap();
// let paths_tx = this.paths_tx.clone();
let file_mtimes = file_mtimes.unwrap();
let parsing_files_tx = this.parsing_files_tx.clone();
// smol::block_on(async move {
// for change in changes.into_iter() {
// let change_path = change.0.clone();
// log::info!("Change: {:?}", &change_path);
// if let Ok(language) = language_registry
// .language_for_file(&change_path.to_path_buf(), None)
// .await
// {
// if language
// .grammar()
// .and_then(|grammar| grammar.embedding_config.as_ref())
// .is_none()
// {
// continue;
// }
smol::block_on(async move {
for change in changes.into_iter() {
let change_path = change.0.clone();
log::info!("Change: {:?}", &change_path);
if let Ok(language) = language_registry
.language_for_file(&change_path.to_path_buf(), None)
.await
{
if language
.grammar()
.and_then(|grammar| grammar.embedding_config.as_ref())
.is_none()
{
continue;
}
// // TODO: Make this a bit more defensive
// let modified_time =
// change_path.metadata().unwrap().modified().unwrap();
// let existing_time =
// file_mtimes.get(&change_path.to_path_buf());
// let already_stored =
// existing_time.map_or(false, |existing_time| {
// if &modified_time != existing_time
// && existing_time.elapsed().unwrap().as_secs()
// > REINDEXING_DELAY
// {
// false
// } else {
// true
// }
// });
// TODO: Make this a bit more defensive
let modified_time =
change_path.metadata().unwrap().modified().unwrap();
let existing_time =
file_mtimes.get(&change_path.to_path_buf());
let already_stored =
existing_time.map_or(false, |existing_time| {
if &modified_time != existing_time
&& existing_time.elapsed().unwrap().as_secs()
> REINDEXING_DELAY
{
false
} else {
true
}
});
// if !already_stored {
// log::info!("Need to reindex: {:?}", &change_path);
// paths_tx
// .try_send((
// worktree_db_id,
// change_path.to_path_buf(),
// language,
// modified_time,
// ))
// .unwrap();
// }
// }
// }
// })
// }
if !already_stored {
log::info!("Need to reindex: {:?}", &change_path);
parsing_files_tx
.try_send((
worktree_db_id,
change_path.to_path_buf(),
language,
modified_time,
))
.unwrap();
}
}
}
})
}
}
});