# quivr/backend/vectorstore/supabase.py
#
# (Captured from a repository web view: 101 lines, 2.9 KiB, Python.)
# Blame timestamp for the lines below: 2023-06-19 21:15:34 +03:00
from typing import Any, List
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import SupabaseVectorStore
from logger import get_logger
from supabase.client import Client
# Blame timestamp: 2023-06-19 21:15:34 +03:00
# Module-level logger for this module.
# NOTE(review): not referenced by any code visible in this file.
logger = get_logger(__name__)
# Blame timestamp: 2023-06-19 21:15:34 +03:00
class CustomSupabaseVectorStore(SupabaseVectorStore):
    """Supabase vector store scoped to a single brain and user.

    Searches do not use the base class's default query. They call a Supabase
    RPC function instead: ``match_vectors`` by default for document chunks and
    ``match_brain`` by default for brains. The brain/user scoping values stored
    on this instance are passed as RPC parameters.
    """

    # Class-level defaults; __init__ always overwrites them per instance.
    brain_id: str = "none"
    user_id: str = "none"
    number_docs: int = 35
    max_input: int = 2000

    def __init__(
        self,
        client: Client,
        embedding: Embeddings,
        table_name: str,
        brain_id: str = "none",
        user_id: str = "none",
        number_docs: int = 35,
        max_input: int = 2000,
    ):
        """Create the store.

        Args:
            client: Supabase client used to run the RPC calls.
            embedding: Embedding model used to embed query text.
            table_name: Table name forwarded to ``SupabaseVectorStore``.
            brain_id: Brain that ``similarity_search`` is restricted to
                (sent as ``p_brain_id``).
            user_id: Fallback user for ``find_brain_closest_query``
                (sent as ``p_user_id``).
            number_docs: Sent as ``match_count`` by ``find_brain_closest_query``.
            max_input: Sent as ``max_chunk_sum`` by ``similarity_search``.
        """
        super().__init__(client, embedding, table_name)
        self.brain_id = brain_id
        self.user_id = user_id
        self.number_docs = number_docs
        self.max_input = max_input

    def find_brain_closest_query(
        self,
        user_id: str,
        query: str,
        k: int = 6,
        table: str = "match_brain",
        threshold: float = 0.5,
    ) -> List[dict]:
        """Return the brains whose embeddings are closest to ``query``.

        Args:
            user_id: User whose brains are searched; sent as ``p_user_id``.
                Falls back to ``self.user_id`` when ``None`` is passed.
            query: Free-text query to embed.
            k: Unused; the match count comes from ``self.number_docs``.
                Kept for interface compatibility.
            table: Name of the Supabase RPC function to call.
            threshold: Unused; kept for interface compatibility.

        Returns:
            One dict per row returned by the RPC, in the returned order. Each
            dict has the keys ``id`` and ``name`` (``None`` when missing) and
            ``similarity`` (``0.0`` when missing).
        """
        # embed_documents (not embed_query) is kept so the query vector is
        # computed exactly as it was before.
        query_embedding = self._embedding.embed_documents([query])[0]
        # Previously the argument was ignored and self.user_id was always used.
        effective_user_id = self.user_id if user_id is None else user_id
        res = self._client.rpc(
            table,
            {
                "query_embedding": query_embedding,
                "match_count": self.number_docs,
                "p_user_id": str(effective_user_id),
            },
        ).execute()
        return [
            {
                "id": item.get("id", None),
                "name": item.get("name", None),
                "similarity": item.get("similarity", 0.0),
            }
            for item in (res.data or [])
        ]

    def similarity_search(
        self,
        query: str,
        k: int = 40,
        table: str = "match_vectors",
        threshold: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return document chunks of ``self.brain_id`` that match ``query``.

        Args:
            query: Free-text query to embed.
            k: Unused; kept for interface compatibility.
            table: Name of the Supabase RPC function to call.
            threshold: Unused; kept for interface compatibility.
            **kwargs: Ignored.

        Returns:
            One ``Document`` per returned row that has non-empty ``content``.
            Its metadata is the row's ``metadata`` merged with the row's ``id``
            and ``similarity``; on a key clash, ``id`` and ``similarity`` win.
        """
        query_embedding = self._embedding.embed_documents([query])[0]
        res = self._client.rpc(
            table,
            {
                "query_embedding": query_embedding,
                # NOTE(review): presumably caps the summed size of the returned
                # chunks — confirm against the SQL function.
                "max_chunk_sum": self.max_input,
                "p_brain_id": str(self.brain_id),
            },
        ).execute()
        return [
            Document(
                metadata={
                    # A metadata value of None (e.g. a NULL column) is treated
                    # as empty instead of raising TypeError on ``**None``.
                    **(row.get("metadata") or {}),
                    "id": row.get("id", ""),
                    "similarity": row.get("similarity", 0.0),
                },
                page_content=row.get("content", ""),
            )
            for row in (res.data or [])
            if row.get("content")
        ]