refactor: chat_routes (#1512)

# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
Zineb El Bachiri 2023-10-30 10:18:23 +01:00 committed by GitHub
parent 82bcf38b16
commit e3a99d1ace
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 208 additions and 136 deletions

View File

View File

@ -0,0 +1,43 @@
from llm.qa_base import QABaseBrainPicking
from routes.authorizations.brain_authorization import validate_brain_authorization
from routes.authorizations.types import RoleEnum
from routes.chat.interface import ChatInterface
from repository.brain import get_brain_details
class BrainfulChat(ChatInterface):
def validate_authorization(self, user_id, brain_id):
if brain_id:
validate_brain_authorization(
brain_id=brain_id,
user_id=user_id,
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
)
def get_openai_api_key(self, brain_id, user_id):
brain_details = get_brain_details(brain_id)
if brain_details:
return brain_details.openai_api_key
def get_answer_generator(
self,
brain_id,
chat_id,
model,
max_tokens,
temperature,
user_openai_api_key,
streaming,
prompt_id,
):
return QABaseBrainPicking(
chat_id=chat_id,
model=model,
max_tokens=max_tokens,
temperature=temperature,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
streaming=streaming,
prompt_id=prompt_id,
)

View File

@ -0,0 +1,36 @@
from llm.qa_headless import HeadlessQA
from routes.chat.interface import ChatInterface
from repository.user_identity import get_user_identity
class BrainlessChat(ChatInterface):
def validate_authorization(self, user_id, brain_id):
pass
def get_openai_api_key(self, brain_id, user_id):
user_identity = get_user_identity(user_id)
if user_identity is not None:
return user_identity.openai_api_key
def get_answer_generator(
self,
brain_id,
chat_id,
model,
max_tokens,
temperature,
user_openai_api_key,
streaming,
prompt_id,
):
return HeadlessQA(
chat_id=chat_id,
model=model,
max_tokens=max_tokens,
temperature=temperature,
user_openai_api_key=user_openai_api_key,
streaming=streaming,
prompt_id=prompt_id,
)

View File

@ -0,0 +1,11 @@
from uuid import UUID
from .brainful_chat import BrainfulChat
from .brainless_chat import BrainlessChat
def get_chat_strategy(brain_id: UUID | None = None):
if brain_id:
return BrainfulChat()
else:
return BrainlessChat()

View File

@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
class ChatInterface(ABC):
@abstractmethod
def validate_authorization(self, user_id, required_roles):
pass
@abstractmethod
def get_openai_api_key(self, brain_id, user_id):
pass
@abstractmethod
def get_answer_generator(
self,
brain_id,
chat_id,
model,
max_tokens,
temperature,
user_openai_api_key,
streaming,
prompt_id,
):
pass

View File

@ -0,0 +1,57 @@
import time
from uuid import UUID
from fastapi import HTTPException
from models import UserIdentity, UserUsage
from models.databases.supabase.supabase import SupabaseDB
class NullableUUID(UUID):
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v) -> UUID | None:
if v == "":
return None
try:
return UUID(v)
except ValueError:
return None
def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
try:
supabase_db.delete_chat_history(chat_id)
except Exception as e:
print(e)
pass
try:
supabase_db.delete_chat(chat_id)
except Exception as e:
print(e)
pass
def check_user_requests_limit(
user: UserIdentity,
):
userDailyUsage = UserUsage(
id=user.id, email=user.email, openai_api_key=user.openai_api_key
)
userSettings = userDailyUsage.get_user_settings()
date = time.strftime("%Y%m%d")
userDailyUsage.handle_increment_user_request_count(date)
if user.openai_api_key is None:
daily_chat_credit = userSettings.get("daily_chat_credit", 0)
if int(userDailyUsage.daily_requests_count) >= int(daily_chat_credit):
raise HTTPException(
status_code=429, # pyright: ignore reportPrivateUsage=none
detail="You have reached the maximum number of requests for today.", # pyright: ignore reportPrivateUsage=none
)
else:
pass

View File

@ -1,4 +1,3 @@
import time
from typing import List, Optional
from uuid import UUID
from venv import logger
@ -18,8 +17,6 @@ from models import (
get_supabase_db,
)
from models.databases.supabase.chats import QuestionAndAnswer
from models.databases.supabase.supabase import SupabaseDB
from repository.brain import get_brain_details
from repository.chat import (
ChatUpdatableProperties,
CreateChatProperties,
@ -35,64 +32,16 @@ from repository.chat.get_chat_history_with_notifications import (
get_chat_history_with_notifications,
)
from repository.notification.remove_chat_notifications import remove_chat_notifications
from repository.user_identity import get_user_identity
from routes.authorizations.brain_authorization import validate_brain_authorization
from routes.authorizations.types import RoleEnum
from routes.chat.factory import get_chat_strategy
from routes.chat.utils import (
NullableUUID,
check_user_requests_limit,
delete_chat_from_db,
)
chat_router = APIRouter()
class NullableUUID(UUID):
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v) -> UUID | None:
if v == "":
return None
try:
return UUID(v)
except ValueError:
return None
def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
try:
supabase_db.delete_chat_history(chat_id)
except Exception as e:
print(e)
pass
try:
supabase_db.delete_chat(chat_id)
except Exception as e:
print(e)
pass
def check_user_requests_limit(
user: UserIdentity,
):
userDailyUsage = UserUsage(
id=user.id, email=user.email, openai_api_key=user.openai_api_key
)
userSettings = userDailyUsage.get_user_settings()
date = time.strftime("%Y%m%d")
userDailyUsage.handle_increment_user_request_count(date)
if user.openai_api_key is None:
daily_chat_credit = userSettings.get("daily_chat_credit", 0)
if int(userDailyUsage.daily_requests_count) >= int(daily_chat_credit):
raise HTTPException(
status_code=429, # pyright: ignore reportPrivateUsage=none
detail="You have reached the maximum number of requests for today.", # pyright: ignore reportPrivateUsage=none
)
else:
pass
@chat_router.get("/chat/healthz", tags=["Health"])
async def healthz():
return {"status": "ok"}
@ -186,20 +135,10 @@ async def create_question_handler(
"""
Add a new question to the chat.
"""
if brain_id:
validate_brain_authorization(
brain_id=brain_id,
user_id=current_user.id,
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
)
# Retrieve user's OpenAI API key
if brain_id:
validate_brain_authorization(
brain_id=brain_id,
user_id=current_user.id,
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
)
chat_instance = get_chat_strategy(brain_id)
chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
current_user.openai_api_key = request.headers.get("Openai-Api-Key")
brain = Brain(id=brain_id)
@ -213,17 +152,10 @@ async def create_question_handler(
userSettings = userDailyUsage.get_user_settings()
is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore
if not current_user.openai_api_key and brain_id:
brain_details = get_brain_details(brain_id)
if brain_details:
current_user.openai_api_key = brain_details.openai_api_key
if not current_user.openai_api_key:
user_identity = get_user_identity(current_user.id)
if user_identity is not None:
current_user.openai_api_key = user_identity.openai_api_key
current_user.openai_api_key = chat_instance.get_openai_api_key(
brain_id=brain_id, user_id=current_user.id
)
# Retrieve chat model (temperature, max_tokens, model)
if (
not chat_question.model
@ -241,8 +173,7 @@ async def create_question_handler(
check_user_requests_limit(current_user)
is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore
gpt_answer_generator: HeadlessQA | QABaseBrainPicking
if brain_id:
gpt_answer_generator = QABaseBrainPicking(
gpt_answer_generator = chat_instance.get_answer_generator(
chat_id=str(chat_id),
model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore
max_tokens=chat_question.max_tokens,
@ -251,15 +182,6 @@ async def create_question_handler(
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
prompt_id=chat_question.prompt_id,
)
else:
gpt_answer_generator = HeadlessQA(
model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore
temperature=chat_question.temperature,
max_tokens=chat_question.max_tokens,
user_openai_api_key=current_user.openai_api_key,
chat_id=str(chat_id),
prompt_id=chat_question.prompt_id,
)
chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)
@ -287,12 +209,8 @@ async def create_stream_question_handler(
| None = Query(..., description="The ID of the brain"),
current_user: UserIdentity = Depends(get_current_user),
) -> StreamingResponse:
if brain_id:
validate_brain_authorization(
brain_id=brain_id,
user_id=current_user.id,
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
)
chat_instance = get_chat_strategy(brain_id)
chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
# Retrieve user's OpenAI API key
current_user.openai_api_key = request.headers.get("Openai-Api-Key")
@ -305,16 +223,11 @@ async def create_stream_question_handler(
)
userSettings = userDailyUsage.get_user_settings()
if not current_user.openai_api_key and brain_id:
brain_details = get_brain_details(brain_id)
if brain_details:
current_user.openai_api_key = brain_details.openai_api_key
if not current_user.openai_api_key:
user_identity = get_user_identity(current_user.id)
if user_identity is not None:
current_user.openai_api_key = user_identity.openai_api_key
current_user.openai_api_key = chat_instance.get_openai_api_key(
brain_id=brain_id, user_id=current_user.id
)
# Retrieve chat model (temperature, max_tokens, model)
if (
@ -333,32 +246,19 @@ async def create_stream_question_handler(
gpt_answer_generator: HeadlessQA | QABaseBrainPicking
# TODO check if model is in the list of models available for the user
print(userSettings.get("models", ["gpt-3.5-turbo"])) # type: ignore
is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore
if brain_id:
gpt_answer_generator = QABaseBrainPicking(
gpt_answer_generator = chat_instance.get_answer_generator(
chat_id=str(chat_id),
model=(brain_details or chat_question).model if is_model_ok else "gpt-3.5-turbo", # type: ignore
max_tokens=(brain_details or chat_question).max_tokens, # type: ignore
temperature=(brain_details or chat_question).temperature, # type: ignore
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=True,
prompt_id=chat_question.prompt_id,
brain_id=str(brain_id),
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=True,
prompt_id=chat_question.prompt_id,
)
else:
gpt_answer_generator = HeadlessQA(
model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore
temperature=chat_question.temperature,
max_tokens=chat_question.max_tokens,
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
chat_id=str(chat_id),
streaming=True,
prompt_id=chat_question.prompt_id,
)
print("streaming")
return StreamingResponse(
gpt_answer_generator.generate_stream(chat_id, chat_question),
media_type="text/event-stream",