quivr/backend/routes/chat/brainful_chat.py
Stan Girard 6a041b6f6d
feat: 🎸 openai (#1658)
Cleaning old code to introduce a better pattern.

2023-11-20 01:22:03 +01:00


from fastapi import HTTPException
from llm.api_brain_qa import APIBrainQA
from llm.qa_base import QABaseBrainPicking
from models.brain_entity import BrainType
from repository.brain.get_brain_by_id import get_brain_by_id
from routes.authorizations.brain_authorization import validate_brain_authorization
from routes.authorizations.types import RoleEnum
from routes.chat.interface import ChatInterface

# OpenAI chat models that support function calling; only these are routed
# to the API-brain (function-calling) pipeline.
models_supporting_function_calls = [
    "gpt-4",
    "gpt-4-1106-preview",
    "gpt-4-0613",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-0613",
]


class BrainfulChat(ChatInterface):
    def validate_authorization(self, user_id, brain_id):
        # Require at least Viewer access on the brain before allowing the chat.
        if brain_id:
            validate_brain_authorization(
                brain_id=brain_id,
                user_id=user_id,
                required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
            )

    def get_answer_generator(
        self,
        brain_id,
        chat_id,
        model,
        max_tokens,
        temperature,
        streaming,
        prompt_id,
        user_id,
    ):
        brain = get_brain_by_id(brain_id)

        if not brain:
            raise HTTPException(status_code=404, detail="Brain not found")

        # Document brains, or models without function-calling support, use the
        # retrieval-based QA pipeline.
        if (
            brain.brain_type == BrainType.DOC
            or model not in models_supporting_function_calls
        ):
            return QABaseBrainPicking(
                chat_id=chat_id,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                brain_id=brain_id,
                streaming=streaming,
                prompt_id=prompt_id,
            )

        # Otherwise the brain is an API brain on a function-calling model.
        return APIBrainQA(
            chat_id=chat_id,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            brain_id=brain_id,
            streaming=streaming,
            prompt_id=prompt_id,
            user_id=user_id,
        )