diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 18e38c6e4c..63527cea1c 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -438,6 +438,13 @@ impl VectorDatabase { .filter_map(|row| row.ok()) .collect::>(); + if deserialized_rows.len() == 0 { + return Ok(Vec::new()); + } + + // Get Length of Embeddings Returned + let embedding_len = deserialized_rows[0].1 .0.len(); + let batch_n = 1000; let mut batches = Vec::new(); let mut batch_ids = Vec::new(); @@ -449,7 +456,8 @@ impl VectorDatabase { if batch_ids.len() == batch_n { let embeddings = std::mem::take(&mut batch_embeddings); let ids = std::mem::take(&mut batch_ids); - let array = Array2::from_shape_vec((batch_ids.len(), 1536), embeddings); + let array = + Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings); match array { Ok(array) => { batches.push((ids, array)); @@ -460,8 +468,10 @@ impl VectorDatabase { }); if batch_ids.len() > 0 { - let array = - Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone()); + let array = Array2::from_shape_vec( + (batch_ids.len(), embedding_len), + batch_embeddings.clone(), + ); match array { Ok(array) => { batches.push((batch_ids.clone(), array));