diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 1385b32b4d..9806877660 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -162,14 +162,15 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { async fn embed_batch( &self, spans: Vec, - _credential: ProviderCredential, + credential: ProviderCredential, ) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; - let api_key = OPENAI_API_KEY - .as_ref() - .ok_or_else(|| anyhow!("no api key"))?; + let api_key = match credential { + ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key), + _ => Err(anyhow!("no api key provided")), + }?; let mut request_number = 0; let mut rate_limiting = false; @@ -178,7 +179,7 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { while request_number < MAX_RETRIES { response = self .send_request( - api_key, + &api_key, spans.iter().map(|x| &**x).collect(), request_timeout, ) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 299aa328b5..6f792c78e2 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -41,7 +41,7 @@ pub struct EmbeddingQueue { pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, - provider_credential: ProviderCredential, + pub provider_credential: ProviderCredential, } #[derive(Clone)] diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index f420e0503b..7fb5f749b4 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -281,15 +281,13 @@ impl SemanticIndex { } pub fn authenticate(&mut self, cx: &AppContext) -> bool { - let credential = self.provider_credential.clone(); - match credential { - ProviderCredential::NoCredentials => { - let credential = self.embedding_provider.retrieve_credentials(cx); - self.provider_credential = credential; - } - _ => {} - } + let existing_credential = self.provider_credential.clone(); + let credential = match existing_credential { + ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx), + _ => existing_credential, + }; + self.provider_credential = credential.clone(); self.embedding_queue.lock().set_credential(credential); self.is_authenticated() } @@ -1020,14 +1018,11 @@ impl SemanticIndex { cx: &mut ModelContext, ) -> Task> { if !self.is_authenticated() { - println!("Authenticating"); if !self.authenticate(cx) { return Task::ready(Err(anyhow!("user is not authenticated"))); } } - println!("SHOULD NOW BE AUTHENTICATED"); - if !self.projects.contains_key(&project.downgrade()) { let subscription = cx.subscribe(&project, |this, project, event, cx| match event { project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {