Mirror of https://github.com/StanGirard/quivr.git (synced 2024-12-24 03:41:56 +03:00)
refactor: chat_routes (#1512)
# Description

Please include a summary of the changes and the related issue. Please also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
parent 82bcf38b16
commit e3a99d1ace
backend/routes/chat/__init__.py (new file, 0 lines)
backend/routes/chat/brainful_chat.py (new file, 43 lines)

```python
from llm.qa_base import QABaseBrainPicking
from routes.authorizations.brain_authorization import validate_brain_authorization
from routes.authorizations.types import RoleEnum
from routes.chat.interface import ChatInterface

from repository.brain import get_brain_details


class BrainfulChat(ChatInterface):
    def validate_authorization(self, user_id, brain_id):
        if brain_id:
            validate_brain_authorization(
                brain_id=brain_id,
                user_id=user_id,
                required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
            )

    def get_openai_api_key(self, brain_id, user_id):
        brain_details = get_brain_details(brain_id)
        if brain_details:
            return brain_details.openai_api_key

    def get_answer_generator(
        self,
        brain_id,
        chat_id,
        model,
        max_tokens,
        temperature,
        user_openai_api_key,
        streaming,
        prompt_id,
    ):
        return QABaseBrainPicking(
            chat_id=chat_id,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            brain_id=brain_id,
            user_openai_api_key=user_openai_api_key,
            streaming=streaming,
            prompt_id=prompt_id,
        )
```
backend/routes/chat/brainless_chat.py (new file, 36 lines)

```python
from llm.qa_headless import HeadlessQA
from routes.chat.interface import ChatInterface

from repository.user_identity import get_user_identity


class BrainlessChat(ChatInterface):
    def validate_authorization(self, user_id, brain_id):
        pass

    def get_openai_api_key(self, brain_id, user_id):
        user_identity = get_user_identity(user_id)

        if user_identity is not None:
            return user_identity.openai_api_key

    def get_answer_generator(
        self,
        brain_id,
        chat_id,
        model,
        max_tokens,
        temperature,
        user_openai_api_key,
        streaming,
        prompt_id,
    ):
        return HeadlessQA(
            chat_id=chat_id,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            user_openai_api_key=user_openai_api_key,
            streaming=streaming,
            prompt_id=prompt_id,
        )
```
backend/routes/chat/factory.py (new file, 11 lines)

```python
from uuid import UUID

from .brainful_chat import BrainfulChat
from .brainless_chat import BrainlessChat


def get_chat_strategy(brain_id: UUID | None = None):
    if brain_id:
        return BrainfulChat()
    else:
        return BrainlessChat()
```
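For a quick sense of the dispatch, a minimal usage sketch (hypothetical, not part of this commit):

```python
# Hypothetical usage of the factory above; not part of this commit.
from uuid import uuid4

from routes.chat.brainful_chat import BrainfulChat
from routes.chat.brainless_chat import BrainlessChat
from routes.chat.factory import get_chat_strategy

# A truthy brain_id selects the brain-backed strategy...
assert isinstance(get_chat_strategy(uuid4()), BrainfulChat)
# ...while None falls through to the headless one.
assert isinstance(get_chat_strategy(None), BrainlessChat)
```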
backend/routes/chat/interface.py (new file, 25 lines; the abstract `validate_authorization` signature is aligned here with the `(user_id, brain_id)` signature that both implementations and all call sites use)

```python
from abc import ABC, abstractmethod


class ChatInterface(ABC):
    @abstractmethod
    def validate_authorization(self, user_id, brain_id):
        pass

    @abstractmethod
    def get_openai_api_key(self, brain_id, user_id):
        pass

    @abstractmethod
    def get_answer_generator(
        self,
        brain_id,
        chat_id,
        model,
        max_tokens,
        temperature,
        user_openai_api_key,
        streaming,
        prompt_id,
    ):
        pass
```
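Since every method is an `@abstractmethod`, `ChatInterface` itself cannot be instantiated; only the two concrete strategies above can. A small illustrative check (hypothetical, not part of this commit):

```python
# Hypothetical check; not part of this commit.
from routes.chat.interface import ChatInterface

try:
    ChatInterface()  # type: ignore[abstract]
except TypeError as err:
    # "Can't instantiate abstract class ChatInterface with abstract methods ..."
    print(err)
```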
backend/routes/chat/utils.py (new file, 57 lines)

```python
import time
from uuid import UUID

from fastapi import HTTPException
from models import UserIdentity, UserUsage
from models.databases.supabase.supabase import SupabaseDB


class NullableUUID(UUID):
    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v) -> UUID | None:
        if v == "":
            return None
        try:
            return UUID(v)
        except ValueError:
            return None


def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
    try:
        supabase_db.delete_chat_history(chat_id)
    except Exception as e:
        print(e)
        pass
    try:
        supabase_db.delete_chat(chat_id)
    except Exception as e:
        print(e)
        pass


def check_user_requests_limit(
    user: UserIdentity,
):
    userDailyUsage = UserUsage(
        id=user.id, email=user.email, openai_api_key=user.openai_api_key
    )

    userSettings = userDailyUsage.get_user_settings()

    date = time.strftime("%Y%m%d")
    userDailyUsage.handle_increment_user_request_count(date)

    if user.openai_api_key is None:
        daily_chat_credit = userSettings.get("daily_chat_credit", 0)
        if int(userDailyUsage.daily_requests_count) >= int(daily_chat_credit):
            raise HTTPException(
                status_code=429,  # pyright: ignore reportPrivateUsage=none
                detail="You have reached the maximum number of requests for today.",  # pyright: ignore reportPrivateUsage=none
            )
    else:
        pass
```
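For context, `NullableUUID` exists so that an empty `brain_id` query value coerces to `None` instead of failing validation. A minimal sketch of that behaviour (hypothetical, assuming the Pydantic v1 `__get_validators__` hook that the class targets):

```python
# Hypothetical demonstration of NullableUUID coercion; not part of this commit.
from pydantic import BaseModel

from routes.chat.utils import NullableUUID


class Params(BaseModel):
    brain_id: NullableUUID | None = None


print(Params(brain_id="").brain_id)  # -> None
print(Params(brain_id="123e4567-e89b-12d3-a456-426614174000").brain_id)  # -> a UUID
```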
backend/routes/chat_routes.py (modified)

```diff
@@ -1,4 +1,3 @@
-import time
 from typing import List, Optional
 from uuid import UUID
 from venv import logger
@@ -18,8 +17,6 @@ from models import (
     get_supabase_db,
 )
 from models.databases.supabase.chats import QuestionAndAnswer
-from models.databases.supabase.supabase import SupabaseDB
-from repository.brain import get_brain_details
 from repository.chat import (
     ChatUpdatableProperties,
     CreateChatProperties,
@@ -35,64 +32,16 @@ from repository.chat.get_chat_history_with_notifications import (
     get_chat_history_with_notifications,
 )
 from repository.notification.remove_chat_notifications import remove_chat_notifications
-from repository.user_identity import get_user_identity
-from routes.authorizations.brain_authorization import validate_brain_authorization
-from routes.authorizations.types import RoleEnum
+from routes.chat.factory import get_chat_strategy
+from routes.chat.utils import (
+    NullableUUID,
+    check_user_requests_limit,
+    delete_chat_from_db,
+)
 
 chat_router = APIRouter()
 
 
-class NullableUUID(UUID):
-    @classmethod
-    def __get_validators__(cls):
-        yield cls.validate
-
-    @classmethod
-    def validate(cls, v) -> UUID | None:
-        if v == "":
-            return None
-        try:
-            return UUID(v)
-        except ValueError:
-            return None
-
-
-def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
-    try:
-        supabase_db.delete_chat_history(chat_id)
-    except Exception as e:
-        print(e)
-        pass
-    try:
-        supabase_db.delete_chat(chat_id)
-    except Exception as e:
-        print(e)
-        pass
-
-
-def check_user_requests_limit(
-    user: UserIdentity,
-):
-    userDailyUsage = UserUsage(
-        id=user.id, email=user.email, openai_api_key=user.openai_api_key
-    )
-
-    userSettings = userDailyUsage.get_user_settings()
-
-    date = time.strftime("%Y%m%d")
-    userDailyUsage.handle_increment_user_request_count(date)
-
-    if user.openai_api_key is None:
-        daily_chat_credit = userSettings.get("daily_chat_credit", 0)
-        if int(userDailyUsage.daily_requests_count) >= int(daily_chat_credit):
-            raise HTTPException(
-                status_code=429,  # pyright: ignore reportPrivateUsage=none
-                detail="You have reached the maximum number of requests for today.",  # pyright: ignore reportPrivateUsage=none
-            )
-    else:
-        pass
-
-
 @chat_router.get("/chat/healthz", tags=["Health"])
 async def healthz():
     return {"status": "ok"}
@@ -186,20 +135,10 @@ async def create_question_handler(
     """
     Add a new question to the chat.
     """
-    if brain_id:
-        validate_brain_authorization(
-            brain_id=brain_id,
-            user_id=current_user.id,
-            required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
-        )
-
-    # Retrieve user's OpenAI API key
-    if brain_id:
-        validate_brain_authorization(
-            brain_id=brain_id,
-            user_id=current_user.id,
-            required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
-        )
+    chat_instance = get_chat_strategy(brain_id)
+
+    chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
+
     current_user.openai_api_key = request.headers.get("Openai-Api-Key")
     brain = Brain(id=brain_id)
@@ -213,17 +152,10 @@ async def create_question_handler(
     userSettings = userDailyUsage.get_user_settings()
+    is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
 
-    if not current_user.openai_api_key and brain_id:
-        brain_details = get_brain_details(brain_id)
-        if brain_details:
-            current_user.openai_api_key = brain_details.openai_api_key
-
-    if not current_user.openai_api_key:
-        user_identity = get_user_identity(current_user.id)
-
-        if user_identity is not None:
-            current_user.openai_api_key = user_identity.openai_api_key
-
+    current_user.openai_api_key = chat_instance.get_openai_api_key(
+        brain_id=brain_id, user_id=current_user.id
+    )
     # Retrieve chat model (temperature, max_tokens, model)
     if (
         not chat_question.model
@@ -241,25 +173,15 @@ async def create_question_handler(
     check_user_requests_limit(current_user)
-    is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
     gpt_answer_generator: HeadlessQA | QABaseBrainPicking
-    if brain_id:
-        gpt_answer_generator = QABaseBrainPicking(
-            chat_id=str(chat_id),
-            model=chat_question.model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
-            max_tokens=chat_question.max_tokens,
-            temperature=chat_question.temperature,
-            brain_id=str(brain_id),
-            user_openai_api_key=current_user.openai_api_key,  # pyright: ignore reportPrivateUsage=none
-            prompt_id=chat_question.prompt_id,
-        )
-    else:
-        gpt_answer_generator = HeadlessQA(
-            model=chat_question.model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
-            temperature=chat_question.temperature,
-            max_tokens=chat_question.max_tokens,
-            user_openai_api_key=current_user.openai_api_key,
-            chat_id=str(chat_id),
-            prompt_id=chat_question.prompt_id,
-        )
+    gpt_answer_generator = chat_instance.get_answer_generator(
+        chat_id=str(chat_id),
+        model=chat_question.model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
+        max_tokens=chat_question.max_tokens,
+        temperature=chat_question.temperature,
+        brain_id=str(brain_id),
+        user_openai_api_key=current_user.openai_api_key,  # pyright: ignore reportPrivateUsage=none
+        prompt_id=chat_question.prompt_id,
+    )
 
     chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)
 
@@ -287,12 +209,8 @@ async def create_stream_question_handler(
     | None = Query(..., description="The ID of the brain"),
     current_user: UserIdentity = Depends(get_current_user),
 ) -> StreamingResponse:
-    if brain_id:
-        validate_brain_authorization(
-            brain_id=brain_id,
-            user_id=current_user.id,
-            required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
-        )
+    chat_instance = get_chat_strategy(brain_id)
+    chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
 
     # Retrieve user's OpenAI API key
     current_user.openai_api_key = request.headers.get("Openai-Api-Key")
@@ -305,16 +223,11 @@ async def create_stream_question_handler(
     )
 
     userSettings = userDailyUsage.get_user_settings()
-    if not current_user.openai_api_key and brain_id:
-        brain_details = get_brain_details(brain_id)
-        if brain_details:
-            current_user.openai_api_key = brain_details.openai_api_key
-
-    if not current_user.openai_api_key:
-        user_identity = get_user_identity(current_user.id)
-
-        if user_identity is not None:
-            current_user.openai_api_key = user_identity.openai_api_key
+    current_user.openai_api_key = chat_instance.get_openai_api_key(
+        brain_id=brain_id, user_id=current_user.id
+    )
 
     # Retrieve chat model (temperature, max_tokens, model)
     if (
@@ -333,32 +246,19 @@ async def create_stream_question_handler(
     gpt_answer_generator: HeadlessQA | QABaseBrainPicking
     # TODO check if model is in the list of models available for the user
 
     print(userSettings.get("models", ["gpt-3.5-turbo"]))  # type: ignore
     is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
 
-    if brain_id:
-        gpt_answer_generator = QABaseBrainPicking(
-            chat_id=str(chat_id),
-            model=(brain_details or chat_question).model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
-            max_tokens=(brain_details or chat_question).max_tokens,  # type: ignore
-            temperature=(brain_details or chat_question).temperature,  # type: ignore
-            brain_id=str(brain_id),
-            user_openai_api_key=current_user.openai_api_key,  # pyright: ignore reportPrivateUsage=none
-            streaming=True,
-            prompt_id=chat_question.prompt_id,
-        )
-    else:
-        gpt_answer_generator = HeadlessQA(
-            model=chat_question.model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
-            temperature=chat_question.temperature,
-            max_tokens=chat_question.max_tokens,
-            user_openai_api_key=current_user.openai_api_key,  # pyright: ignore reportPrivateUsage=none
-            chat_id=str(chat_id),
-            streaming=True,
-            prompt_id=chat_question.prompt_id,
-        )
+    gpt_answer_generator = chat_instance.get_answer_generator(
+        chat_id=str(chat_id),
+        model=(brain_details or chat_question).model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
+        max_tokens=(brain_details or chat_question).max_tokens,  # type: ignore
+        temperature=(brain_details or chat_question).temperature,  # type: ignore
+        user_openai_api_key=current_user.openai_api_key,  # pyright: ignore reportPrivateUsage=none
+        streaming=True,
+        prompt_id=chat_question.prompt_id,
+        brain_id=str(brain_id),
+    )
 
     print("streaming")
     return StreamingResponse(
         gpt_answer_generator.generate_stream(chat_id, chat_question),
         media_type="text/event-stream",
```
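To make the streaming contract concrete, a hedged client-side sketch; the endpoint path, port, and payload shape are assumptions inferred from the handler above (which reads `Openai-Api-Key` from the request headers and returns `text/event-stream`), not details confirmed by this diff:

```python
# Hypothetical client for the streaming endpoint; URL, port, and payload are assumed.
import httpx

with httpx.stream(
    "POST",
    "http://localhost:5050/chat/<chat_id>/question/stream",
    params={"brain_id": "<brain_id>"},
    headers={"Openai-Api-Key": "sk-..."},
    json={"question": "What is in my brain?"},
    timeout=None,
) as response:
    for chunk in response.iter_text():
        print(chunk, end="", flush=True)
```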