feat: 🎸 rag

now works with 30 chunks
Stan Girard 2024-01-25 20:19:56 -08:00
parent 67c71bbc1d
commit e7bd571ac5
6 changed files with 14 additions and 10 deletions


@@ -60,7 +60,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
     temperature: float = 0.1
     chat_id: str = None # pyright: ignore reportPrivateUsage=none
     brain_id: str # pyright: ignore reportPrivateUsage=none
-    max_tokens: int = 256
+    max_tokens: int = 2000
     streaming: bool = False
     knowledge_qa: Optional[RAGInterface]
     metadata: Optional[dict] = None


@@ -60,7 +60,7 @@ class QuivrRAG(BaseModel, RAGInterface):
     temperature: float = 0.1
     chat_id: str = None # pyright: ignore reportPrivateUsage=none
     brain_id: str = None # pyright: ignore reportPrivateUsage=none
-    max_tokens: int = 256
+    max_tokens: int = 2000
     streaming: bool = False

     @property
@@ -91,6 +91,7 @@ class QuivrRAG(BaseModel, RAGInterface):
         chat_id: str,
         streaming: bool = False,
         prompt_id: Optional[UUID] = None,
+        max_tokens: int = 2000,
         **kwargs,
     ):
         super().__init__(
@@ -103,6 +104,7 @@ class QuivrRAG(BaseModel, RAGInterface):
         self.supabase_client = self._create_supabase_client()
         self.vector_store = self._create_vector_store()
         self.prompt_id = prompt_id
+        self.max_tokens = max_tokens

     def _create_supabase_client(self) -> Client:
         return create_client(

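A minimal sketch of the pattern the QuivrRAG hunks above rely on (a pydantic field default that an explicit constructor argument overrides); the class and field names below are simplified stand-ins, not the project's actual code:

# Simplified sketch of the QuivrRAG pattern above; not the real class.
from pydantic import BaseModel

class ExampleRAG(BaseModel):
    # Field default mirrors the new value introduced in this commit.
    max_tokens: int = 2000

    def __init__(self, max_tokens: int = 2000, **kwargs):
        super().__init__(**kwargs)
        # An explicit argument overrides the field default, as QuivrRAG.__init__ now does.
        self.max_tokens = max_tokens

rag = ExampleRAG(max_tokens=2000)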

@@ -59,9 +59,7 @@ class Repository(ABC):
         pass

     @abstractmethod
-    def similarity_search(
-        self, query_embedding, table: str, top_k: int, threshold: float
-    ):
+    def similarity_search(self, query_embedding, table: str, k: int, threshold: float):
         pass

     @abstractmethod


@@ -30,12 +30,12 @@ class Vector(Repository):
         return response

     # TODO: remove duplicate similarity_search in supabase vector store
-    def similarity_search(self, query_embedding, table, top_k, threshold):
+    def similarity_search(self, query_embedding, table, k, threshold):
         response = self.db.rpc(
             table,
             {
                 "query_embedding": query_embedding,
-                "match_count": top_k,
+                "match_count": k,
                 "match_threshold": threshold,
             },
         ).execute()

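For illustration, a hedged sketch of a call site for the renamed repository method; the vector_repository instance, the embedding vector, and the "match_vectors" RPC name are placeholders assumed for the example, not taken from this diff:

# Hypothetical call site; vector_repository, embedding and the RPC name are placeholders.
results = vector_repository.similarity_search(
    query_embedding=embedding,  # embedding vector computed for the user question
    table="match_vectors",      # assumed name of the Postgres RPC to call
    k=20,                       # renamed from top_k in this commit
    threshold=0.8,              # forwarded as match_threshold to the RPC
)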

@@ -29,6 +29,7 @@ def get_question_context_from_brain(brain_id: UUID, question: str) -> str:
         embeddings,
         table_name="vectors",
         brain_id=str(brain_id),
+        number_docs=20,
     )

     documents = vector_store.similarity_search(question, k=20, threshold=0.8)

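Pieced together from the hunk above, a rough sketch of the updated helper; the import, the supabase_client/embeddings objects, and the return expression are assumptions for illustration, only the arguments shown in the diff are confirmed:

# Sketch only: client/embeddings setup and the return value are assumed, not shown in the diff.
from uuid import UUID

def get_question_context_from_brain(brain_id: UUID, question: str) -> str:
    vector_store = CustomSupabaseVectorStore(
        supabase_client,          # assumed to be created earlier in the module
        embeddings,
        table_name="vectors",
        brain_id=str(brain_id),
        number_docs=20,           # new argument added by this commit
    )
    documents = vector_store.similarity_search(question, k=20, threshold=0.8)
    # Assumed return shape: concatenated page contents of the retrieved documents.
    return "\n".join(doc.page_content for doc in documents)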

@@ -14,6 +14,7 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
     brain_id: str = "none"
     user_id: str = "none"
+    number_docs: int = 35

     def __init__(
         self,
@@ -22,10 +23,12 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
         table_name: str,
         brain_id: str = "none",
         user_id: str = "none",
+        number_docs: int = 35,
     ):
         super().__init__(client, embedding, table_name)
         self.brain_id = brain_id
         self.user_id = user_id
+        self.number_docs = number_docs

     def find_brain_closest_query(
         self,
@@ -42,7 +45,7 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
             table,
             {
                 "query_embedding": query_embedding,
-                "match_count": k,
+                "match_count": self.number_docs,
                 "p_user_id": str(self.user_id),
             },
         ).execute()
@@ -62,7 +65,7 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
     def similarity_search(
         self,
         query: str,
-        k: int = 6,
+        k: int = 35,
         table: str = "match_vectors",
         threshold: float = 0.5,
         **kwargs: Any,
@@ -73,7 +76,7 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
             table,
             {
                 "query_embedding": query_embedding,
-                "match_count": k,
+                "match_count": self.number_docs,
                 "p_brain_id": str(self.brain_id),
             },
         ).execute()
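
One behavioural note on the last two hunks, shown as a hedged usage sketch: match_count is now taken from the instance's number_docs, so the k argument to similarity_search no longer controls how many rows the RPC returns; supabase_client and embeddings below are placeholders:

# Illustration of the new behaviour; supabase_client and embeddings are placeholders.
store = CustomSupabaseVectorStore(
    supabase_client,
    embeddings,
    table_name="vectors",
    brain_id="my-brain-id",
    number_docs=30,               # this value now drives match_count in the RPC
)
# Even with k=5, the RPC is issued with match_count=30 (self.number_docs).
docs = store.similarity_search("What is RAG?", k=5, threshold=0.8)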