2023-06-19 21:15:34 +03:00
|
|
|
from typing import Any, List
|
|
|
|
|
|
|
|
from langchain.docstore.document import Document
|
2023-08-21 13:45:32 +03:00
|
|
|
from langchain.embeddings.base import Embeddings
|
2023-06-19 21:15:34 +03:00
|
|
|
from langchain.vectorstores import SupabaseVectorStore
|
2023-07-10 15:27:49 +03:00
|
|
|
from supabase.client import Client
|
2023-06-19 21:15:34 +03:00
|
|
|
|
|
|
|
|
|
|
|
class CustomSupabaseVectorStore(SupabaseVectorStore):
    """Supabase vector store whose similarity search is scoped to one brain.

    Searches call a Postgres RPC function (``match_vectors`` by default)
    through ``client.rpc`` rather than querying a table directly. This store's
    ``brain_id`` is passed as the ``p_brain_id`` argument. Presumably the SQL
    function uses it to restrict matches to that brain; confirm against the
    function definition.
    """

    # Class-level default; each instance overwrites it in __init__.
    brain_id: str = "none"

    def __init__(
        self,
        client: Client,
        embedding: Embeddings,
        table_name: str,
        brain_id: str = "none",
    ):
        """Create the store.

        Args:
            client: Supabase client used to issue RPC calls.
            embedding: Embedding model used to vectorize search queries.
            table_name: Table name forwarded to ``SupabaseVectorStore``.
            brain_id: Identifier sent as ``p_brain_id`` on every search.
        """
        super().__init__(client, embedding, table_name)
        self.brain_id = brain_id

    def similarity_search(
        self,
        query: str,
        k: int = 20,
        table: str = "match_vectors",
        threshold: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents matching ``query`` for this store's brain.

        Args:
            query: Text to search for.
            k: Sent to the RPC function as ``match_count``.
            table: Name of the Postgres RPC function to call. Despite the
                parameter name, it is invoked via ``client.rpc``, not read
                as a table.
            threshold: Currently unused; accepted for interface
                compatibility only.
            **kwargs: Ignored; accepted for compatibility with callers of
                the base ``similarity_search``.

        Returns:
            One ``Document`` per returned row that has non-empty
            ``content``, in the order the RPC returned them. Similarity
            scores are not included.
        """
        # NOTE(review): the query is embedded with embed_documents rather than
        # embed_query. Switching could change results for models that encode
        # queries differently, so the original call is kept.
        query_embedding = self._embedding.embed_documents([query])[0]

        res = self._client.rpc(
            table,
            {
                "query_embedding": query_embedding,
                "match_count": k,
                "p_brain_id": str(self.brain_id),
            },
        ).execute()

        # Guard against a missing result set and against NULL metadata
        # columns: Document expects a dict, and .get("metadata", {}) alone
        # would still pass None through when the column is NULL.
        return [
            Document(
                metadata=row.get("metadata") or {},  # type: ignore
                page_content=row["content"],
            )
            for row in (res.data or [])
            if row.get("content")
        ]
|