feat(brain): add endpoint to return context to question (#1044)

This commit is contained in:
Joey Wang 2023-08-27 00:38:41 -07:00 committed by GitHub
parent d7a508acdd
commit 30cb91531f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 14 deletions

View File

@ -43,6 +43,10 @@ class BrainUpdatableProperties(BaseModel):
return brain_dict return brain_dict
class BrainQuestionRequest(BaseModel):
question: str
class Brain(Repository): class Brain(Repository):
def __init__(self, supabase_client): def __init__(self, supabase_client):
self.db = supabase_client self.db = supabase_client

View File

@ -10,3 +10,4 @@ from .update_user_rights import update_brain_user_rights
from .get_default_user_brain import get_user_default_brain from .get_default_user_brain import get_user_default_brain
from .set_as_default_brain_for_user import set_as_default_brain_for_user from .set_as_default_brain_for_user import set_as_default_brain_for_user
from .get_default_user_brain_or_create_new import get_default_user_brain_or_create_new from .get_default_user_brain_or_create_new import get_default_user_brain_or_create_new
from .get_question_context_from_brain import get_question_context_from_brain

View File

@ -0,0 +1,20 @@
from uuid import UUID
from models.settings import get_embeddings, get_supabase_client
from vectorstore.supabase import CustomSupabaseVectorStore
def get_question_context_from_brain(brain_id: UUID, question: str) -> str:
supabase_client = get_supabase_client()
embeddings = get_embeddings()
vector_store = CustomSupabaseVectorStore(
supabase_client,
embeddings,
table_name="vectors",
brain_id=brain_id,
)
documents = vector_store.similarity_search(question)
# aggregate all the documents into one string
return "\n".join([doc.page_content for doc in documents])

View File

@ -4,20 +4,16 @@ from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from logger import get_logger from logger import get_logger
from models import BrainRateLimiting, UserIdentity from models import BrainRateLimiting, UserIdentity
from models.databases.supabase.brains import ( from models.databases.supabase.brains import (BrainQuestionRequest,
BrainUpdatableProperties, BrainUpdatableProperties,
CreateBrainProperties, CreateBrainProperties)
) from repository.brain import (create_brain, create_brain_user,
from repository.brain import ( get_brain_details,
create_brain, get_default_user_brain_or_create_new,
create_brain_user, get_question_context_from_brain, get_user_brains,
get_brain_details, get_user_default_brain,
get_default_user_brain_or_create_new, set_as_default_brain_for_user,
get_user_brains, update_brain_by_id)
get_user_default_brain,
set_as_default_brain_for_user,
update_brain_by_id,
)
from repository.prompt import delete_prompt_by_id, get_prompt_by_id from repository.prompt import delete_prompt_by_id, get_prompt_by_id
from routes.authorizations.brain_authorization import has_brain_authorization from routes.authorizations.brain_authorization import has_brain_authorization
from routes.authorizations.types import RoleEnum from routes.authorizations.types import RoleEnum
@ -207,3 +203,26 @@ async def set_as_default_brain_endpoint(
set_as_default_brain_for_user(user.id, brain_id) set_as_default_brain_for_user(user.id, brain_id)
return {"message": f"Brain {brain_id} has been set as default brain."} return {"message": f"Brain {brain_id} has been set as default brain."}
@brain_router.post(
"/brains/{brain_id}/question_context",
dependencies=[
Depends(
AuthBearer(),
),
Depends(has_brain_authorization()),
],
tags=["Brain"],
)
async def get_question_context_from_brain_endpoint(
brain_id: UUID,
request: BrainQuestionRequest,
):
"""
Get question context from brain
"""
context = get_question_context_from_brain(brain_id, request.question)
return {"context": context}