mirror of
https://github.com/StanGirard/quivr.git
synced 2024-12-24 20:03:41 +03:00
refactor: chat_routes (#1512)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
This commit is contained in:
parent
82bcf38b16
commit
e3a99d1ace
0
backend/routes/chat/__init_.py
Normal file
0
backend/routes/chat/__init_.py
Normal file
43
backend/routes/chat/brainful_chat.py
Normal file
43
backend/routes/chat/brainful_chat.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from llm.qa_base import QABaseBrainPicking
|
||||||
|
from routes.authorizations.brain_authorization import validate_brain_authorization
|
||||||
|
from routes.authorizations.types import RoleEnum
|
||||||
|
from routes.chat.interface import ChatInterface
|
||||||
|
|
||||||
|
from repository.brain import get_brain_details
|
||||||
|
|
||||||
|
|
||||||
|
class BrainfulChat(ChatInterface):
|
||||||
|
def validate_authorization(self, user_id, brain_id):
|
||||||
|
if brain_id:
|
||||||
|
validate_brain_authorization(
|
||||||
|
brain_id=brain_id,
|
||||||
|
user_id=user_id,
|
||||||
|
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_openai_api_key(self, brain_id, user_id):
|
||||||
|
brain_details = get_brain_details(brain_id)
|
||||||
|
if brain_details:
|
||||||
|
return brain_details.openai_api_key
|
||||||
|
|
||||||
|
def get_answer_generator(
|
||||||
|
self,
|
||||||
|
brain_id,
|
||||||
|
chat_id,
|
||||||
|
model,
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
user_openai_api_key,
|
||||||
|
streaming,
|
||||||
|
prompt_id,
|
||||||
|
):
|
||||||
|
return QABaseBrainPicking(
|
||||||
|
chat_id=chat_id,
|
||||||
|
model=model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
brain_id=brain_id,
|
||||||
|
user_openai_api_key=user_openai_api_key,
|
||||||
|
streaming=streaming,
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
)
|
36
backend/routes/chat/brainless_chat.py
Normal file
36
backend/routes/chat/brainless_chat.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from llm.qa_headless import HeadlessQA
|
||||||
|
from routes.chat.interface import ChatInterface
|
||||||
|
|
||||||
|
from repository.user_identity import get_user_identity
|
||||||
|
|
||||||
|
|
||||||
|
class BrainlessChat(ChatInterface):
|
||||||
|
def validate_authorization(self, user_id, brain_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_openai_api_key(self, brain_id, user_id):
|
||||||
|
user_identity = get_user_identity(user_id)
|
||||||
|
|
||||||
|
if user_identity is not None:
|
||||||
|
return user_identity.openai_api_key
|
||||||
|
|
||||||
|
def get_answer_generator(
|
||||||
|
self,
|
||||||
|
brain_id,
|
||||||
|
chat_id,
|
||||||
|
model,
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
user_openai_api_key,
|
||||||
|
streaming,
|
||||||
|
prompt_id,
|
||||||
|
):
|
||||||
|
return HeadlessQA(
|
||||||
|
chat_id=chat_id,
|
||||||
|
model=model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
user_openai_api_key=user_openai_api_key,
|
||||||
|
streaming=streaming,
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
)
|
11
backend/routes/chat/factory.py
Normal file
11
backend/routes/chat/factory.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from .brainful_chat import BrainfulChat
|
||||||
|
from .brainless_chat import BrainlessChat
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_strategy(brain_id: UUID | None = None):
|
||||||
|
if brain_id:
|
||||||
|
return BrainfulChat()
|
||||||
|
else:
|
||||||
|
return BrainlessChat()
|
25
backend/routes/chat/interface.py
Normal file
25
backend/routes/chat/interface.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class ChatInterface(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def validate_authorization(self, user_id, required_roles):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_openai_api_key(self, brain_id, user_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_answer_generator(
|
||||||
|
self,
|
||||||
|
brain_id,
|
||||||
|
chat_id,
|
||||||
|
model,
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
user_openai_api_key,
|
||||||
|
streaming,
|
||||||
|
prompt_id,
|
||||||
|
):
|
||||||
|
pass
|
57
backend/routes/chat/utils.py
Normal file
57
backend/routes/chat/utils.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import time
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from models import UserIdentity, UserUsage
|
||||||
|
from models.databases.supabase.supabase import SupabaseDB
|
||||||
|
|
||||||
|
|
||||||
|
class NullableUUID(UUID):
|
||||||
|
@classmethod
|
||||||
|
def __get_validators__(cls):
|
||||||
|
yield cls.validate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, v) -> UUID | None:
|
||||||
|
if v == "":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return UUID(v)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
|
||||||
|
try:
|
||||||
|
supabase_db.delete_chat_history(chat_id)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
supabase_db.delete_chat(chat_id)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def check_user_requests_limit(
|
||||||
|
user: UserIdentity,
|
||||||
|
):
|
||||||
|
userDailyUsage = UserUsage(
|
||||||
|
id=user.id, email=user.email, openai_api_key=user.openai_api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
userSettings = userDailyUsage.get_user_settings()
|
||||||
|
|
||||||
|
date = time.strftime("%Y%m%d")
|
||||||
|
userDailyUsage.handle_increment_user_request_count(date)
|
||||||
|
|
||||||
|
if user.openai_api_key is None:
|
||||||
|
daily_chat_credit = userSettings.get("daily_chat_credit", 0)
|
||||||
|
if int(userDailyUsage.daily_requests_count) >= int(daily_chat_credit):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429, # pyright: ignore reportPrivateUsage=none
|
||||||
|
detail="You have reached the maximum number of requests for today.", # pyright: ignore reportPrivateUsage=none
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pass
|
@ -1,4 +1,3 @@
|
|||||||
import time
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from venv import logger
|
from venv import logger
|
||||||
@ -18,8 +17,6 @@ from models import (
|
|||||||
get_supabase_db,
|
get_supabase_db,
|
||||||
)
|
)
|
||||||
from models.databases.supabase.chats import QuestionAndAnswer
|
from models.databases.supabase.chats import QuestionAndAnswer
|
||||||
from models.databases.supabase.supabase import SupabaseDB
|
|
||||||
from repository.brain import get_brain_details
|
|
||||||
from repository.chat import (
|
from repository.chat import (
|
||||||
ChatUpdatableProperties,
|
ChatUpdatableProperties,
|
||||||
CreateChatProperties,
|
CreateChatProperties,
|
||||||
@ -35,64 +32,16 @@ from repository.chat.get_chat_history_with_notifications import (
|
|||||||
get_chat_history_with_notifications,
|
get_chat_history_with_notifications,
|
||||||
)
|
)
|
||||||
from repository.notification.remove_chat_notifications import remove_chat_notifications
|
from repository.notification.remove_chat_notifications import remove_chat_notifications
|
||||||
from repository.user_identity import get_user_identity
|
from routes.chat.factory import get_chat_strategy
|
||||||
from routes.authorizations.brain_authorization import validate_brain_authorization
|
from routes.chat.utils import (
|
||||||
from routes.authorizations.types import RoleEnum
|
NullableUUID,
|
||||||
|
check_user_requests_limit,
|
||||||
|
delete_chat_from_db,
|
||||||
|
)
|
||||||
|
|
||||||
chat_router = APIRouter()
|
chat_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class NullableUUID(UUID):
|
|
||||||
@classmethod
|
|
||||||
def __get_validators__(cls):
|
|
||||||
yield cls.validate
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate(cls, v) -> UUID | None:
|
|
||||||
if v == "":
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return UUID(v)
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
|
|
||||||
try:
|
|
||||||
supabase_db.delete_chat_history(chat_id)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
supabase_db.delete_chat(chat_id)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def check_user_requests_limit(
|
|
||||||
user: UserIdentity,
|
|
||||||
):
|
|
||||||
userDailyUsage = UserUsage(
|
|
||||||
id=user.id, email=user.email, openai_api_key=user.openai_api_key
|
|
||||||
)
|
|
||||||
|
|
||||||
userSettings = userDailyUsage.get_user_settings()
|
|
||||||
|
|
||||||
date = time.strftime("%Y%m%d")
|
|
||||||
userDailyUsage.handle_increment_user_request_count(date)
|
|
||||||
|
|
||||||
if user.openai_api_key is None:
|
|
||||||
daily_chat_credit = userSettings.get("daily_chat_credit", 0)
|
|
||||||
if int(userDailyUsage.daily_requests_count) >= int(daily_chat_credit):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=429, # pyright: ignore reportPrivateUsage=none
|
|
||||||
detail="You have reached the maximum number of requests for today.", # pyright: ignore reportPrivateUsage=none
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@chat_router.get("/chat/healthz", tags=["Health"])
|
@chat_router.get("/chat/healthz", tags=["Health"])
|
||||||
async def healthz():
|
async def healthz():
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
@ -186,20 +135,10 @@ async def create_question_handler(
|
|||||||
"""
|
"""
|
||||||
Add a new question to the chat.
|
Add a new question to the chat.
|
||||||
"""
|
"""
|
||||||
if brain_id:
|
|
||||||
validate_brain_authorization(
|
|
||||||
brain_id=brain_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Retrieve user's OpenAI API key
|
chat_instance = get_chat_strategy(brain_id)
|
||||||
if brain_id:
|
|
||||||
validate_brain_authorization(
|
chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
|
||||||
brain_id=brain_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
|
|
||||||
)
|
|
||||||
|
|
||||||
current_user.openai_api_key = request.headers.get("Openai-Api-Key")
|
current_user.openai_api_key = request.headers.get("Openai-Api-Key")
|
||||||
brain = Brain(id=brain_id)
|
brain = Brain(id=brain_id)
|
||||||
@ -213,17 +152,10 @@ async def create_question_handler(
|
|||||||
userSettings = userDailyUsage.get_user_settings()
|
userSettings = userDailyUsage.get_user_settings()
|
||||||
is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore
|
is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore
|
||||||
|
|
||||||
if not current_user.openai_api_key and brain_id:
|
|
||||||
brain_details = get_brain_details(brain_id)
|
|
||||||
if brain_details:
|
|
||||||
current_user.openai_api_key = brain_details.openai_api_key
|
|
||||||
|
|
||||||
if not current_user.openai_api_key:
|
if not current_user.openai_api_key:
|
||||||
user_identity = get_user_identity(current_user.id)
|
current_user.openai_api_key = chat_instance.get_openai_api_key(
|
||||||
|
brain_id=brain_id, user_id=current_user.id
|
||||||
if user_identity is not None:
|
)
|
||||||
current_user.openai_api_key = user_identity.openai_api_key
|
|
||||||
|
|
||||||
# Retrieve chat model (temperature, max_tokens, model)
|
# Retrieve chat model (temperature, max_tokens, model)
|
||||||
if (
|
if (
|
||||||
not chat_question.model
|
not chat_question.model
|
||||||
@ -241,25 +173,15 @@ async def create_question_handler(
|
|||||||
check_user_requests_limit(current_user)
|
check_user_requests_limit(current_user)
|
||||||
is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore
|
is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore
|
||||||
gpt_answer_generator: HeadlessQA | QABaseBrainPicking
|
gpt_answer_generator: HeadlessQA | QABaseBrainPicking
|
||||||
if brain_id:
|
gpt_answer_generator = chat_instance.get_answer_generator(
|
||||||
gpt_answer_generator = QABaseBrainPicking(
|
chat_id=str(chat_id),
|
||||||
chat_id=str(chat_id),
|
model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore
|
||||||
model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore
|
max_tokens=chat_question.max_tokens,
|
||||||
max_tokens=chat_question.max_tokens,
|
temperature=chat_question.temperature,
|
||||||
temperature=chat_question.temperature,
|
brain_id=str(brain_id),
|
||||||
brain_id=str(brain_id),
|
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
|
||||||
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
|
prompt_id=chat_question.prompt_id,
|
||||||
prompt_id=chat_question.prompt_id,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
gpt_answer_generator = HeadlessQA(
|
|
||||||
model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore
|
|
||||||
temperature=chat_question.temperature,
|
|
||||||
max_tokens=chat_question.max_tokens,
|
|
||||||
user_openai_api_key=current_user.openai_api_key,
|
|
||||||
chat_id=str(chat_id),
|
|
||||||
prompt_id=chat_question.prompt_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)
|
chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)
|
||||||
|
|
||||||
@ -287,12 +209,8 @@ async def create_stream_question_handler(
|
|||||||
| None = Query(..., description="The ID of the brain"),
|
| None = Query(..., description="The ID of the brain"),
|
||||||
current_user: UserIdentity = Depends(get_current_user),
|
current_user: UserIdentity = Depends(get_current_user),
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
if brain_id:
|
chat_instance = get_chat_strategy(brain_id)
|
||||||
validate_brain_authorization(
|
chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
|
||||||
brain_id=brain_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Retrieve user's OpenAI API key
|
# Retrieve user's OpenAI API key
|
||||||
current_user.openai_api_key = request.headers.get("Openai-Api-Key")
|
current_user.openai_api_key = request.headers.get("Openai-Api-Key")
|
||||||
@ -305,16 +223,11 @@ async def create_stream_question_handler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
userSettings = userDailyUsage.get_user_settings()
|
userSettings = userDailyUsage.get_user_settings()
|
||||||
if not current_user.openai_api_key and brain_id:
|
|
||||||
brain_details = get_brain_details(brain_id)
|
|
||||||
if brain_details:
|
|
||||||
current_user.openai_api_key = brain_details.openai_api_key
|
|
||||||
|
|
||||||
if not current_user.openai_api_key:
|
if not current_user.openai_api_key:
|
||||||
user_identity = get_user_identity(current_user.id)
|
current_user.openai_api_key = chat_instance.get_openai_api_key(
|
||||||
|
brain_id=brain_id, user_id=current_user.id
|
||||||
if user_identity is not None:
|
)
|
||||||
current_user.openai_api_key = user_identity.openai_api_key
|
|
||||||
|
|
||||||
# Retrieve chat model (temperature, max_tokens, model)
|
# Retrieve chat model (temperature, max_tokens, model)
|
||||||
if (
|
if (
|
||||||
@ -333,32 +246,19 @@ async def create_stream_question_handler(
|
|||||||
gpt_answer_generator: HeadlessQA | QABaseBrainPicking
|
gpt_answer_generator: HeadlessQA | QABaseBrainPicking
|
||||||
# TODO check if model is in the list of models available for the user
|
# TODO check if model is in the list of models available for the user
|
||||||
|
|
||||||
print(userSettings.get("models", ["gpt-3.5-turbo"])) # type: ignore
|
|
||||||
is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore
|
is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore
|
||||||
|
|
||||||
if brain_id:
|
gpt_answer_generator = chat_instance.get_answer_generator(
|
||||||
gpt_answer_generator = QABaseBrainPicking(
|
chat_id=str(chat_id),
|
||||||
chat_id=str(chat_id),
|
model=(brain_details or chat_question).model if is_model_ok else "gpt-3.5-turbo", # type: ignore
|
||||||
model=(brain_details or chat_question).model if is_model_ok else "gpt-3.5-turbo", # type: ignore
|
max_tokens=(brain_details or chat_question).max_tokens, # type: ignore
|
||||||
max_tokens=(brain_details or chat_question).max_tokens, # type: ignore
|
temperature=(brain_details or chat_question).temperature, # type: ignore
|
||||||
temperature=(brain_details or chat_question).temperature, # type: ignore
|
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
|
||||||
brain_id=str(brain_id),
|
streaming=True,
|
||||||
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
|
prompt_id=chat_question.prompt_id,
|
||||||
streaming=True,
|
brain_id=str(brain_id),
|
||||||
prompt_id=chat_question.prompt_id,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
gpt_answer_generator = HeadlessQA(
|
|
||||||
model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore
|
|
||||||
temperature=chat_question.temperature,
|
|
||||||
max_tokens=chat_question.max_tokens,
|
|
||||||
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
|
|
||||||
chat_id=str(chat_id),
|
|
||||||
streaming=True,
|
|
||||||
prompt_id=chat_question.prompt_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("streaming")
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
gpt_answer_generator.generate_stream(chat_id, chat_question),
|
gpt_answer_generator.generate_stream(chat_id, chat_question),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
|
Loading…
Reference in New Issue
Block a user