feat: Refactor get_question_context_for_brain endpoint (#1872)

Return a list of DocumentAnswer objects instead of a single concatenated context string.

# Description

This change reworks the `get_question_context_for_brain` endpoint so that, instead of concatenating matched chunks into one context string, it returns a list of `DocumentAnswer` objects describing the source files behind a question. The route moves from `/brains/{brain_id}/question_context` to `/brains/{brain_id}/documents`; results are deduplicated by `file_sha1`; the similarity search is bounded with `k=20` and `threshold=0.8`; the custom Supabase vector store now surfaces each match's `id` and `similarity` in document metadata; and the frontend gains a `getDocsFromQuestion` helper exposed through `useBrainApi`.
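
For a quick before/after, the response body changes roughly as follows (field values here are illustrative, not taken from a real brain):

```python
# Before: one concatenated context string.
old_response = {"context": "First matched chunk...\nSecond matched chunk..."}

# After: one entry per distinct source file, mirroring the DocumentAnswer fields.
new_response = {
    "docs": [
        {
            "file_name": "quarterly_report.pdf",  # illustrative
            "file_sha1": "0b7e...",               # deduplication key (truncated here)
            "file_size": 52431,
            "file_url": "",
            "file_id": "d0c1...",
            "file_similarity": 0.87,
        }
    ]
}
```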

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

Commit 36b008e0eb (parent b09d93e547) by Stan Girard, 2023-12-12 22:33:23 +01:00, committed by GitHub.
5 changed files with 54 additions and 17 deletions.


```diff
@@ -220,12 +220,14 @@ async def set_brain_as_default(
 
 @brain_router.post(
-    "/brains/{brain_id}/question_context",
+    "/brains/{brain_id}/documents",
     dependencies=[Depends(AuthBearer()), Depends(has_brain_authorization())],
     tags=["Brain"],
 )
-async def get_question_context_for_brain(brain_id: UUID, request: BrainQuestionRequest):
+async def get_question_context_for_brain(
+    brain_id: UUID, question: BrainQuestionRequest
+):
     # TODO: Move this endpoint to AnswerGenerator service
     """Retrieve the question context from a specific brain."""
-    context = get_question_context_from_brain(brain_id, request.question)
-    return {"context": context}
+    context = get_question_context_from_brain(brain_id, question.question)
+    return {"docs": context}
 
```


@ -1,5 +1,6 @@
from uuid import UUID from uuid import UUID
from attr import dataclass
from logger import get_logger from logger import get_logger
from models.settings import get_embeddings, get_supabase_client from models.settings import get_embeddings, get_supabase_client
from vectorstore.supabase import CustomSupabaseVectorStore from vectorstore.supabase import CustomSupabaseVectorStore
@ -7,6 +8,16 @@ from vectorstore.supabase import CustomSupabaseVectorStore
logger = get_logger(__name__) logger = get_logger(__name__)
@dataclass
class DocumentAnswer:
file_name: str
file_sha1: str
file_size: int
file_url: str = ""
file_id: str = ""
file_similarity: float = 0.0
def get_question_context_from_brain(brain_id: UUID, question: str) -> str: def get_question_context_from_brain(brain_id: UUID, question: str) -> str:
# TODO: Move to AnswerGenerator service # TODO: Move to AnswerGenerator service
supabase_client = get_supabase_client() supabase_client = get_supabase_client()
@ -18,16 +29,22 @@ def get_question_context_from_brain(brain_id: UUID, question: str) -> str:
table_name="vectors", table_name="vectors",
brain_id=str(brain_id), brain_id=str(brain_id),
) )
documents = vector_store.similarity_search(question) documents = vector_store.similarity_search(question, k=20, threshold=0.8)
## I can't pass more than 2500 tokens to as return value in my array. So i need to remove the docs after i reach 2000 tokens. A token equals 1.5 characters. So 2000 tokens is 3000 characters.
tokens = 0
for doc in documents:
tokens += len(doc.page_content) * 1.5
if tokens > 3000:
documents.remove(doc)
logger.info("documents", documents)
logger.info("tokens", tokens)
logger.info("🔥🔥🔥🔥🔥🔥")
# aggregate all the documents into one string ## Create a list of DocumentAnswer objects from the documents but with no duplicates file_sha1
return "\n".join([doc.page_content for doc in documents]) answers = []
file_sha1s = []
for document in documents:
if document.metadata["file_sha1"] not in file_sha1s:
file_sha1s.append(document.metadata["file_sha1"])
answers.append(
DocumentAnswer(
file_name=document.metadata["file_name"],
file_sha1=document.metadata["file_sha1"],
file_size=document.metadata["file_size"],
file_id=document.metadata["id"],
file_similarity=document.metadata["similarity"],
)
)
return answers
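
To make the deduplication concrete, here is a self-contained sketch of the same loop run against stand-in matches (`FakeDocument` and the sample metadata are invented for illustration):

```python
from dataclasses import dataclass, field


@dataclass
class FakeDocument:
    """Stand-in for a vector-store match; only metadata matters here."""

    metadata: dict = field(default_factory=dict)


# Two chunks from the same file (same file_sha1) plus one from another file,
# ordered by descending similarity as a similarity search typically returns them.
documents = [
    FakeDocument({"file_name": "a.pdf", "file_sha1": "sha-A", "similarity": 0.91}),
    FakeDocument({"file_name": "a.pdf", "file_sha1": "sha-A", "similarity": 0.85}),
    FakeDocument({"file_name": "b.pdf", "file_sha1": "sha-B", "similarity": 0.80}),
]

answers = []
file_sha1s = []
for document in documents:
    if document.metadata["file_sha1"] not in file_sha1s:
        file_sha1s.append(document.metadata["file_sha1"])
        answers.append(document.metadata)

print([a["file_name"] for a in answers])  # ['a.pdf', 'b.pdf']
```

Because matches typically arrive sorted by similarity, the first chunk kept for each file is also its best-scoring one.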


```diff
@@ -43,7 +43,11 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
         match_result = [
             (
                 Document(
-                    metadata=search.get("metadata", {}),  # type: ignore
+                    metadata={
+                        **search.get("metadata", {}),
+                        "id": search.get("id", ""),
+                        "similarity": search.get("similarity", 0.0),
+                    },
                     page_content=search.get("content", ""),
                 ),
                 search.get("similarity", 0.0),
```


```diff
@@ -139,3 +139,14 @@ export const updateBrainSecrets = async (
 ): Promise<void> => {
   await axiosInstance.put(`/brains/${brainId}/secrets-values`, secrets);
 };
+
+export const getDocsFromQuestion = async (
+  brainId: string,
+  question: string,
+  axiosInstance: AxiosInstance
+): Promise<string[]> => {
+  return (await axiosInstance.post<Record<"docs", string[]>>(`/brains/${brainId}/documents`, {
+    question,
+  })).data.docs;
+};
+
```


```diff
@@ -8,6 +8,7 @@ import {
   getBrains,
   getBrainUsers,
   getDefaultBrain,
+  getDocsFromQuestion,
   getPublicBrains,
   setAsDefaultBrain,
   Subscription,
@@ -48,6 +49,8 @@ export const useBrainApi = () => {
     updateBrain: async (brainId: string, brain: UpdateBrainInput) =>
       updateBrain(brainId, brain, axiosInstance),
     getPublicBrains: async () => getPublicBrains(axiosInstance),
+    getDocsFromQuestion: async (brainId: string, question: string) =>
+      getDocsFromQuestion(brainId, question, axiosInstance),
     updateBrainSecrets: async (
       brainId: string,
       secrets: Record<string, string>
```