leverage embeddings len returned in construction matrix multiplication

This commit is contained in:
KCaverly 2023-09-27 10:33:04 -04:00
parent 3682751455
commit 0e6fd645fd

View File

@ -438,6 +438,13 @@ impl VectorDatabase {
.filter_map(|row| row.ok())
.collect::<Vec<(usize, Embedding)>>();
if deserialized_rows.len() == 0 {
return Ok(Vec::new());
}
// Get Length of Embeddings Returned
let embedding_len = deserialized_rows[0].1 .0.len();
let batch_n = 1000;
let mut batches = Vec::new();
let mut batch_ids = Vec::new();
@ -449,7 +456,8 @@ impl VectorDatabase {
if batch_ids.len() == batch_n {
let embeddings = std::mem::take(&mut batch_embeddings);
let ids = std::mem::take(&mut batch_ids);
let array = Array2::from_shape_vec((batch_ids.len(), 1536), embeddings);
let array =
Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
match array {
Ok(array) => {
batches.push((ids, array));
@ -460,8 +468,10 @@ impl VectorDatabase {
});
if batch_ids.len() > 0 {
let array =
Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
let array = Array2::from_shape_vec(
(batch_ids.len(), embedding_len),
batch_embeddings.clone(),
);
match array {
Ok(array) => {
batches.push((batch_ids.clone(), array));