From 0e6fd645fd71bf77d1bdff28c30985ac23229aaf Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 27 Sep 2023 10:33:04 -0400 Subject: [PATCH] leverage embeddings len returned in construction matrix multiplication --- crates/semantic_index/src/db.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 18e38c6e4c..63527cea1c 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -438,6 +438,13 @@ impl VectorDatabase { .filter_map(|row| row.ok()) .collect::>(); + if deserialized_rows.len() == 0 { + return Ok(Vec::new()); + } + + // Get Length of Embeddings Returned + let embedding_len = deserialized_rows[0].1 .0.len(); + let batch_n = 1000; let mut batches = Vec::new(); let mut batch_ids = Vec::new(); @@ -449,7 +456,8 @@ impl VectorDatabase { if batch_ids.len() == batch_n { let embeddings = std::mem::take(&mut batch_embeddings); let ids = std::mem::take(&mut batch_ids); - let array = Array2::from_shape_vec((batch_ids.len(), 1536), embeddings); + let array = + Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings); match array { Ok(array) => { batches.push((ids, array)); @@ -460,8 +468,10 @@ impl VectorDatabase { }); if batch_ids.len() > 0 { - let array = - Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone()); + let array = Array2::from_shape_vec( + (batch_ids.len(), embedding_len), + batch_embeddings.clone(), + ); match array { Ok(array) => { batches.push((batch_ids.clone(), array));