refactor: chat for multibrains (#1812)

# Description

- Chat Module
- External Api Secrets Interface, exposed through brain service
This commit is contained in:
Zineb El Bachiri 2023-12-04 18:38:54 +01:00 committed by GitHub
parent 8ddb7708fd
commit 436e49a5e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
52 changed files with 516 additions and 582 deletions

View File

@ -12,8 +12,8 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from logger import get_logger
from middlewares.cors import add_cors_middleware
from modules.chat.controller import chat_router
from modules.misc.controller import misc_router
from routes.chat_routes import chat_router
logger = get_logger(__name__)

View File

@ -3,24 +3,25 @@ from typing import Optional
from uuid import UUID
from fastapi import HTTPException
from logger import get_logger
from litellm import completion
from llm.qa_base import QABaseBrainPicking
from llm.utils.call_brain_api import call_brain_api
from llm.utils.get_api_brain_definition_as_json_schema import (
get_api_brain_definition_as_json_schema,
)
from models.chats import ChatQuestion
from models.databases.supabase.chats import CreateChatHistory
from logger import get_logger
from modules.brain.service.brain_service import BrainService
from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
brain_service = BrainService()
chat_service = ChatService()
logger = get_logger(__name__)
class APIBrainQA(
QABaseBrainPicking,
):
@ -54,7 +55,7 @@ class APIBrainQA(
messages,
functions,
brain_id: UUID,
recursive_count = 0,
recursive_count=0,
):
if recursive_count > 5:
yield "🧠<Deciding what to do>🧠"
@ -80,13 +81,16 @@ class APIBrainQA(
finish_reason = chunk.choices[0].finish_reason
if finish_reason == "stop":
break
if "function_call" in chunk.choices[0].delta and chunk.choices[0].delta["function_call"]:
if (
"function_call" in chunk.choices[0].delta
and chunk.choices[0].delta["function_call"]
):
if chunk.choices[0].delta["function_call"].name:
function_call["name"] = chunk.choices[0].delta["function_call"].name
if chunk.choices[0].delta["function_call"].arguments:
function_call["arguments"] += chunk.choices[0].delta[
"function_call"
].arguments
function_call["arguments"] += (
chunk.choices[0].delta["function_call"].arguments
)
elif finish_reason == "function_call":
try:
@ -154,7 +158,7 @@ class APIBrainQA(
messages = [{"role": "system", "content": prompt_content}]
history = get_chat_history(self.chat_id)
history = chat_service.get_chat_history(self.chat_id)
for message in history:
formatted_message = [
@ -165,7 +169,7 @@ class APIBrainQA(
messages.append({"role": "user", "content": question.question})
streamed_chat_history = update_chat_history(
streamed_chat_history = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
@ -203,7 +207,7 @@ class APIBrainQA(
for token in response_tokens
if not token.startswith("🧠<") and not token.endswith(">🧠")
]
update_message_by_id(
chat_service.update_message_by_id(
message_id=str(streamed_chat_history.message_id),
user_message=question.question,
assistant="".join(response_tokens),

View File

@ -15,21 +15,17 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from llm.utils.format_chat_history import format_chat_history
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger
from models import BrainSettings # Importing settings related to the 'brain'
from models.chats import ChatQuestion
from models.databases.supabase.chats import CreateChatHistory
from modules.brain.service.brain_service import BrainService
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from pydantic import BaseModel
from repository.chat import (
GetChatHistoryOutput,
format_chat_history,
get_chat_history,
update_chat_history,
update_message_by_id,
)
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
@ -40,6 +36,7 @@ QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you
brain_service = BrainService()
chat_service = ChatService()
class QABaseBrainPicking(BaseModel):
@ -155,7 +152,6 @@ class QABaseBrainPicking(BaseModel):
if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
api_base = self.brain_settings.ollama_api_base_url
return ChatLiteLLM(
temperature=temperature,
max_tokens=self.max_tokens,
@ -163,7 +159,7 @@ class QABaseBrainPicking(BaseModel):
streaming=streaming,
verbose=False,
callbacks=callbacks,
api_base= api_base
api_base=api_base,
) # pyright: ignore reportPrivateUsage=none
def _create_prompt_template(self):
@ -192,7 +188,9 @@ class QABaseBrainPicking(BaseModel):
def generate_answer(
self, chat_id: UUID, question: ChatQuestion
) -> GetChatHistoryOutput:
transformed_history = format_chat_history(get_chat_history(self.chat_id))
transformed_history = format_chat_history(
chat_service.get_chat_history(self.chat_id)
)
answering_llm = self._create_llm(
model=self.model,
streaming=False,
@ -230,7 +228,7 @@ class QABaseBrainPicking(BaseModel):
answer = model_response["answer"]
new_chat = update_chat_history(
new_chat = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
@ -264,7 +262,7 @@ class QABaseBrainPicking(BaseModel):
async def generate_stream(
self, chat_id: UUID, question: ChatQuestion
) -> AsyncIterable:
history = get_chat_history(self.chat_id)
history = chat_service.get_chat_history(self.chat_id)
callback = AsyncIteratorCallbackHandler()
self.callbacks = [callback]
@ -323,7 +321,7 @@ class QABaseBrainPicking(BaseModel):
if question.brain_id:
brain = brain_service.get_brain_by_id(question.brain_id)
streamed_chat_history = update_chat_history(
streamed_chat_history = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
@ -386,7 +384,7 @@ class QABaseBrainPicking(BaseModel):
assistant += sources_string
try:
update_message_by_id(
chat_service.update_message_by_id(
message_id=str(streamed_chat_history.message_id),
user_message=question.question,
assistant=assistant,

View File

@ -7,27 +7,25 @@ from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.chat_models import ChatLiteLLM
from langchain.chat_models.base import BaseChatModel
from models import BrainSettings # Importing settings related to the 'brain'
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from logger import get_logger
from models.chats import ChatQuestion
from models.databases.supabase.chats import CreateChatHistory
from modules.prompt.entity.prompt import Prompt
from pydantic import BaseModel
from repository.chat import (
GetChatHistoryOutput,
from llm.utils.format_chat_history import (
format_chat_history,
format_history_to_openai_mesages,
get_chat_history,
update_chat_history,
update_message_by_id,
)
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger
from models import BrainSettings # Importing settings related to the 'brain'
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from modules.prompt.entity.prompt import Prompt
from pydantic import BaseModel
logger = get_logger(__name__)
SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer.When answering use markdown or any other techniques to display the content in a nice and aerated way."
chat_service = ChatService()
class HeadlessQA(BaseModel):
@ -104,7 +102,10 @@ class HeadlessQA(BaseModel):
def generate_answer(
self, chat_id: UUID, question: ChatQuestion
) -> GetChatHistoryOutput:
transformed_history = format_chat_history(get_chat_history(self.chat_id))
# Move format_chat_history to chat service ?
transformed_history = format_chat_history(
chat_service.get_chat_history(self.chat_id)
)
prompt_content = (
self.prompt_to_use.content if self.prompt_to_use else SYSTEM_MESSAGE
)
@ -120,7 +121,7 @@ class HeadlessQA(BaseModel):
model_prediction = answering_llm.predict_messages(messages)
answer = model_prediction.content
new_chat = update_chat_history(
new_chat = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
@ -152,7 +153,9 @@ class HeadlessQA(BaseModel):
callback = AsyncIteratorCallbackHandler()
self.callbacks = [callback]
transformed_history = format_chat_history(get_chat_history(self.chat_id))
transformed_history = format_chat_history(
chat_service.get_chat_history(self.chat_id)
)
prompt_content = (
self.prompt_to_use.content if self.prompt_to_use else SYSTEM_MESSAGE
)
@ -186,7 +189,7 @@ class HeadlessQA(BaseModel):
),
)
streamed_chat_history = update_chat_history(
streamed_chat_history = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
@ -221,7 +224,7 @@ class HeadlessQA(BaseModel):
await run
assistant = "".join(response_tokens)
update_message_by_id(
chat_service.update_message_by_id(
message_id=str(streamed_chat_history.message_id),
user_message=question.question,
assistant=assistant,

View File

@ -5,10 +5,12 @@ from llm.utils.extract_api_brain_definition_values_from_llm_output import (
extract_api_brain_definition_values_from_llm_output,
)
from llm.utils.make_api_request import get_api_call_response_as_text
from modules.brain.service.brain_service import BrainService
from repository.api_brain_definition.get_api_brain_definition import (
get_api_brain_definition,
)
from repository.external_api_secret.read_secret import read_secret
brain_service = BrainService()
def call_brain_api(brain_id: UUID, user_id: UUID, arguments: dict) -> str:
@ -31,7 +33,7 @@ def call_brain_api(brain_id: UUID, user_id: UUID, arguments: dict) -> str:
secrets_values = {}
for secret in secrets:
secret_value = read_secret(
secret_value = brain_service.external_api_secrets_repository.read_secret(
user_id=user_id, brain_id=brain_id, secret_name=secret.name
)
secrets_values[secret.name] = secret_value

View File

@ -1,6 +1,4 @@
import os
from fastapi import FastAPI
if __name__ == "__main__":
# import needed here when running main.py to debug backend
@ -14,6 +12,7 @@ from fastapi.responses import JSONResponse
from logger import get_logger
from middlewares.cors import add_cors_middleware
from modules.api_key.controller import api_key_router
from modules.chat.controller import chat_router
from modules.contact_support.controller import contact_router
from modules.knowledge.controller import knowledge_router
from modules.misc.controller import misc_router
@ -22,12 +21,10 @@ from modules.onboarding.controller import onboarding_router
from modules.prompt.controller import prompt_router
from modules.upload.controller import upload_router
from modules.user.controller import user_router
from packages.utils import handle_request_validation_error
from routes.brain_routes import brain_router
from routes.chat_routes import chat_router
from routes.crawl_routes import crawl_router
from routes.subscription_routes import subscription_router
from logger import get_logger
from packages.utils import handle_request_validation_error
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration

View File

@ -1,18 +1,6 @@
from .brains_subscription_invitations import BrainSubscription
from .chat import Chat, ChatHistory
from .chats import ChatMessage, ChatQuestion
from .files import File
from .settings import (BrainRateLimiting, BrainSettings, ResendSettings,
get_documents_vector_store, get_embeddings,
get_supabase_client, get_supabase_db)
from .user_usage import UserUsage
# TODO uncomment the below import when start using SQLalchemy
# from .sqlalchemy_repository import (
# User,
# Brain,
# BrainUser,
# BrainVector,
# BrainSubscriptionInvitation,
# ApiKey
# )

View File

@ -54,50 +54,6 @@ class Repository(ABC):
):
pass
@abstractmethod
def create_chat(self, new_chat):
pass
@abstractmethod
def get_chat_by_id(self, chat_id: str):
pass
@abstractmethod
def get_chat_history(self, chat_id: str):
pass
@abstractmethod
def get_user_chats(self, user_id: str):
pass
@abstractmethod
def update_chat_history(self, chat_id: str, user_message: str, assistant: str):
pass
@abstractmethod
def update_chat(self, chat_id: UUID, updates):
pass
@abstractmethod
def add_question_and_answer(self, chat_id: str, question_and_answer):
pass
@abstractmethod
def update_message_by_id(self, message_id: UUID, updates):
pass
@abstractmethod
def get_chat_details(self, chat_id: UUID):
pass
@abstractmethod
def delete_chat(self, chat_id: UUID):
pass
@abstractmethod
def delete_chat_history(self, chat_id: UUID):
pass
@abstractmethod
def get_vectors_by_file_name(self, file_name: str):
pass

View File

@ -1,7 +1,6 @@
from models.databases.supabase.api_brain_definition import ApiBrainDefinitions
from models.databases.supabase.brains_subscription_invitations import \
BrainSubscription
from models.databases.supabase.chats import Chats
from models.databases.supabase.files import File
from models.databases.supabase.user_usage import UserUsage
from models.databases.supabase.vectors import Vector

View File

@ -2,7 +2,6 @@ from logger import get_logger
from models.databases.supabase import (
ApiBrainDefinitions,
BrainSubscription,
Chats,
File,
UserUsage,
Vector,
@ -15,7 +14,6 @@ class SupabaseDB(
UserUsage,
File,
BrainSubscription,
Chats,
Vector,
ApiBrainDefinitions,
):
@ -24,6 +22,5 @@ class SupabaseDB(
UserUsage.__init__(self, supabase_client)
File.__init__(self, supabase_client)
BrainSubscription.__init__(self, supabase_client)
Chats.__init__(self, supabase_client)
Vector.__init__(self, supabase_client)
ApiBrainDefinitions.__init__(self, supabase_client)

View File

@ -1,14 +1,14 @@
from langchain.embeddings.ollama import OllamaEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from logger import get_logger
from models.databases.supabase.supabase import SupabaseDB
from pydantic import BaseSettings
from supabase.client import Client, create_client
from vectorstore.supabase import SupabaseVectorStore
from langchain.embeddings.ollama import OllamaEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from logger import get_logger
logger = get_logger(__name__)
class BrainRateLimiting(BaseSettings):
max_brain_per_user: int = 5

View File

@ -5,7 +5,6 @@ from models.settings import get_supabase_client
from modules.brain.dto.inputs import BrainUpdatableProperties
from modules.brain.entity.brain_entity import BrainEntity, PublicBrain
from modules.brain.repository.interfaces.brains_interface import BrainsInterface
from repository.external_api_secret.utils import build_secret_unique_name
logger = get_logger(__name__)
@ -97,15 +96,3 @@ class Brains(BrainsInterface):
return None
return BrainEntity(**response[0])
def delete_secret(self, user_id: UUID, brain_id: UUID, secret_name: str) -> bool:
response = self.db.rpc(
"delete_secret",
{
"secret_name": build_secret_unique_name(
user_id=user_id, brain_id=brain_id, secret_name=secret_name
),
},
).execute()
return response.data

View File

@ -0,0 +1,60 @@
from uuid import UUID
from models.settings import get_supabase_client
from modules.brain.repository.interfaces.external_api_secrets_interface import (
ExternalApiSecretsInterface,
)
def build_secret_unique_name(user_id: UUID, brain_id: UUID, secret_name: str):
return f"{user_id}-{brain_id}-{secret_name}"
class ExternalApiSecrets(ExternalApiSecretsInterface):
def __init__(self):
supabase_client = get_supabase_client()
self.db = supabase_client
def create_secret(
self, user_id: UUID, brain_id: UUID, secret_name: str, secret_value
) -> UUID | None:
response = self.db.rpc(
"insert_secret",
{
"name": build_secret_unique_name(
user_id=user_id, brain_id=brain_id, secret_name=secret_name
),
"secret": secret_value,
},
).execute()
return response.data
def read_secret(
self,
user_id: UUID,
brain_id: UUID,
secret_name: str,
) -> UUID | None:
response = self.db.rpc(
"read_secret",
{
"secret_name": build_secret_unique_name(
user_id=user_id, brain_id=brain_id, secret_name=secret_name
),
},
).execute()
return response.data
def delete_secret(self, user_id: UUID, brain_id: UUID, secret_name: str) -> bool:
response = self.db.rpc(
"delete_secret",
{
"secret_name": build_secret_unique_name(
user_id=user_id, brain_id=brain_id, secret_name=secret_name
),
},
).execute()
return response.data

View File

@ -56,10 +56,3 @@ class BrainsInterface(ABC):
Get a brain by id
"""
pass
@abstractmethod
def delete_secret(self, user_id: UUID, brain_id: UUID, secret_name: str) -> bool:
"""
Delete a secret from a brain
"""
pass

View File

@ -1,41 +1,29 @@
from abc import ABC, abstractmethod
from typing import List
from uuid import UUID
# TODO: Replace BrainsVectors with KnowledgeVectors interface instead
class BrainsVectorsInterface(ABC):
class ExternalApiSecretsInterface(ABC):
@abstractmethod
def create_brain_vector(self, brain_id, vector_id, file_sha1):
def create_secret(
self, user_id: UUID, brain_id: UUID, secret_name: str, secret_value
) -> UUID | None:
"""
Create a brain vector
Create a new secret for the API Request in given brain
"""
pass
@abstractmethod
def get_vector_ids_from_file_sha1(self, file_sha1: str):
def read_secret(
self, user_id: UUID, brain_id: UUID, secret_name: str
) -> UUID | None:
"""
Get vector ids from file sha1
Read a secret for the API Request in given brain
"""
pass
@abstractmethod
def get_brain_vector_ids(self, brain_id) -> List[UUID]:
def delete_secret(self, user_id: UUID, brain_id: UUID, secret_name: str) -> bool:
"""
Get brain vector ids
"""
pass
@abstractmethod
def delete_file_from_brain(self, brain_id, file_name: str):
"""
Delete file from brain
"""
pass
@abstractmethod
def delete_brain_vector(self, brain_id: str):
"""
Delete brain vector
Delete a secret from a brain
"""
pass

View File

@ -7,6 +7,7 @@ from modules.brain.entity.brain_entity import BrainEntity, BrainType, PublicBrai
from modules.brain.repository.brains import Brains
from modules.brain.repository.brains_users import BrainsUsers
from modules.brain.repository.brains_vectors import BrainsVectors
from modules.brain.repository.external_api_secrets import ExternalApiSecrets
from modules.brain.repository.interfaces.brains_interface import BrainsInterface
from modules.brain.repository.interfaces.brains_users_interface import (
BrainsUsersInterface,
@ -14,6 +15,9 @@ from modules.brain.repository.interfaces.brains_users_interface import (
from modules.brain.repository.interfaces.brains_vectors_interface import (
BrainsVectorsInterface,
)
from modules.brain.repository.interfaces.external_api_secrets_interface import (
ExternalApiSecretsInterface,
)
from modules.knowledge.service.knowledge_service import KnowledgeService
from repository.api_brain_definition.add_api_brain_definition import (
add_api_brain_definition,
@ -27,7 +31,6 @@ from repository.api_brain_definition.get_api_brain_definition import (
from repository.api_brain_definition.update_api_brain_definition import (
update_api_brain_definition,
)
from repository.external_api_secret.create_secret import create_secret
knowledge_service = KnowledgeService()
@ -36,11 +39,13 @@ class BrainService:
brain_repository: BrainsInterface
brain_user_repository: BrainsUsersInterface
brain_vector_repository: BrainsVectorsInterface
external_api_secrets_repository: ExternalApiSecretsInterface
def __init__(self):
self.brain_repository = Brains()
self.brain_user_repository = BrainsUsers()
self.brain_vector = BrainsVectors()
self.external_api_secrets_repository = ExternalApiSecrets()
def get_brain_by_id(self, brain_id: UUID):
return self.brain_repository.get_brain_by_id(brain_id)
@ -75,7 +80,7 @@ class BrainService:
secrets_values = brain.brain_secrets_values
for secret_name in secrets_values:
create_secret(
self.external_api_secrets_repository.create_secret(
user_id=user_id,
brain_id=created_brain.brain_id,
secret_name=secret_name,
@ -96,7 +101,7 @@ class BrainService:
brain_users = self.brain_user_repository.get_brain_users(brain_id=brain_id)
for user in brain_users:
for secret in secrets:
self.brain_repository.delete_secret(
self.external_api_secrets_repository.delete_secret(
user_id=user.user_id,
brain_id=brain_id,
secret_name=secret.name,
@ -217,12 +222,12 @@ class BrainService:
secret_value: str,
) -> None:
"""Update an existing secret."""
self.brain_repository.delete_secret(
self.external_api_secrets_repository.delete_secret(
user_id=user_id,
brain_id=brain_id,
secret_name=secret_name,
)
create_secret(
self.external_api_secrets_repository.create_secret(
user_id=user_id,
brain_id=brain_id,
secret_name=secret_name,

View File

@ -0,0 +1 @@
from .chat_routes import chat_router

View File

@ -6,7 +6,7 @@ from modules.brain.service.brain_authorization_service import (
validate_brain_authorization,
)
from modules.brain.service.brain_service import BrainService
from routes.chat.interface import ChatInterface
from modules.chat.controller.chat.interface import ChatInterface
models_supporting_function_calls = [
"gpt-4",

View File

@ -1,5 +1,5 @@
from llm.qa_headless import HeadlessQA
from routes.chat.interface import ChatInterface
from modules.chat.controller.chat.interface import ChatInterface
class BrainlessChat(ChatInterface):

View File

@ -3,7 +3,6 @@ from uuid import UUID
from fastapi import HTTPException
from models import UserUsage
from models.databases.supabase.supabase import SupabaseDB
from modules.user.entity.user_identity import UserIdentity
@ -22,19 +21,6 @@ class NullableUUID(UUID):
return None
def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
try:
supabase_db.delete_chat_history(chat_id)
except Exception as e:
print(e)
pass
try:
supabase_db.delete_chat(chat_id)
except Exception as e:
print(e)
pass
def check_user_requests_limit(
user: UserIdentity,
):

View File

@ -7,36 +7,27 @@ from fastapi.responses import StreamingResponse
from llm.qa_base import QABaseBrainPicking
from llm.qa_headless import HeadlessQA
from middlewares.auth import AuthBearer, get_current_user
from models import Chat, ChatQuestion, UserUsage, get_supabase_db
from models.databases.supabase.chats import QuestionAndAnswer
from models.user_usage import UserUsage
from modules.brain.service.brain_service import BrainService
from modules.notification.service.notification_service import NotificationService
from modules.user.entity.user_identity import UserIdentity
from repository.chat import (
from modules.chat.controller.chat.factory import get_chat_strategy
from modules.chat.controller.chat.utils import NullableUUID, check_user_requests_limit
from modules.chat.dto.chats import ChatItem, ChatQuestion
from modules.chat.dto.inputs import (
ChatUpdatableProperties,
CreateChatProperties,
GetChatHistoryOutput,
create_chat,
get_chat_by_id,
get_user_chats,
update_chat,
)
from repository.chat.add_question_and_answer import add_question_and_answer
from repository.chat.get_chat_history_with_notifications import (
ChatItem,
get_chat_history_with_notifications,
)
from routes.chat.factory import get_chat_strategy
from routes.chat.utils import (
NullableUUID,
check_user_requests_limit,
delete_chat_from_db,
QuestionAndAnswer,
)
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.entity.chat import Chat
from modules.chat.service.chat_service import ChatService
from modules.notification.service.notification_service import NotificationService
from modules.user.entity.user_identity import UserIdentity
chat_router = APIRouter()
notification_service = NotificationService()
brain_service = BrainService()
chat_service = ChatService()
@chat_router.get("/chat/healthz", tags=["Health"])
@ -56,7 +47,7 @@ async def get_chats(current_user: UserIdentity = Depends(get_current_user)):
This endpoint retrieves all the chats associated with the current authenticated user. It returns a list of chat objects
containing the chat ID and chat name for each chat.
"""
chats = get_user_chats(str(current_user.id))
chats = chat_service.get_user_chats(str(current_user.id))
return {"chats": chats}
@ -68,10 +59,9 @@ async def delete_chat(chat_id: UUID):
"""
Delete a specific chat by chat ID.
"""
supabase_db = get_supabase_db()
notification_service.remove_chat_notifications(chat_id)
delete_chat_from_db(supabase_db=supabase_db, chat_id=chat_id)
chat_service.delete_chat_from_db(chat_id)
return {"message": f"{chat_id} has been deleted."}
@ -83,18 +73,20 @@ async def update_chat_metadata_handler(
chat_data: ChatUpdatableProperties,
chat_id: UUID,
current_user: UserIdentity = Depends(get_current_user),
) -> Chat:
):
"""
Update chat attributes
"""
chat = get_chat_by_id(chat_id) # pyright: ignore reportPrivateUsage=none
chat = chat_service.get_chat_by_id(
chat_id # pyright: ignore reportPrivateUsage=none
)
if str(current_user.id) != chat.user_id:
raise HTTPException(
status_code=403, # pyright: ignore reportPrivateUsage=none
detail="You should be the owner of the chat to update it.", # pyright: ignore reportPrivateUsage=none
)
return update_chat(chat_id=chat_id, chat_data=chat_data)
return chat_service.update_chat(chat_id=chat_id, chat_data=chat_data)
# create new chat
@ -107,7 +99,7 @@ async def create_chat_handler(
Create a new chat with initial chat messages.
"""
return create_chat(user_id=current_user.id, chat_data=chat_data)
return chat_service.create_chat(user_id=current_user.id, chat_data=chat_data)
# add new question to chat
@ -271,7 +263,7 @@ async def get_chat_history_handler(
chat_id: UUID,
) -> List[ChatItem]:
# TODO: RBAC with current_user
return get_chat_history_with_notifications(chat_id)
return chat_service.get_chat_history_with_notifications(chat_id)
@chat_router.post(
@ -286,4 +278,4 @@ async def add_question_and_answer_handler(
"""
Add a new question and anwser to the chat.
"""
return add_question_and_answer(chat_id, question_and_answer)
return chat_service.add_question_and_answer(chat_id, question_and_answer)

View File

@ -1,6 +1,9 @@
from typing import List, Optional, Tuple
from enum import Enum
from typing import List, Optional, Tuple, Union
from uuid import UUID
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.notification.entity.notification import Notification
from pydantic import BaseModel
@ -23,3 +26,13 @@ class ChatQuestion(BaseModel):
max_tokens: Optional[int]
brain_id: Optional[UUID]
prompt_id: Optional[UUID]
class ChatItemType(Enum):
MESSAGE = "MESSAGE"
NOTIFICATION = "NOTIFICATION"
class ChatItem(BaseModel):
item_type: ChatItemType
body: Union[GetChatHistoryOutput, Notification]

View File

@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import Optional
from uuid import UUID
from pydantic import BaseModel
class CreateChatHistory(BaseModel):
chat_id: UUID
user_message: str
assistant: str
prompt_id: Optional[UUID]
brain_id: Optional[UUID]
class QuestionAndAnswer(BaseModel):
question: str
answer: str
@dataclass
class CreateChatProperties:
name: str
def __init__(self, name: str):
self.name = name
@dataclass
class ChatUpdatableProperties:
chat_name: Optional[str] = None
def __init__(self, chat_name: Optional[str]):
self.chat_name = chat_name

View File

@ -0,0 +1,21 @@
from typing import Optional
from uuid import UUID
from pydantic import BaseModel
class GetChatHistoryOutput(BaseModel):
chat_id: UUID
message_id: UUID
user_message: str
assistant: str
message_time: str
prompt_title: Optional[str] | None
brain_name: Optional[str] | None
def dict(self, *args, **kwargs):
chat_history = super().dict(*args, **kwargs)
chat_history["chat_id"] = str(chat_history.get("chat_id"))
chat_history["message_id"] = str(chat_history.get("message_id"))
return chat_history

View File

@ -1,26 +1,11 @@
from typing import Optional
from uuid import UUID
from models.chat import Chat
from models.databases.repository import Repository
from pydantic import BaseModel
from models.settings import get_supabase_client
from modules.chat.entity.chat import Chat
from modules.chat.repository.chats_interface import ChatsInterface
class CreateChatHistory(BaseModel):
chat_id: UUID
user_message: str
assistant: str
prompt_id: Optional[UUID]
brain_id: Optional[UUID]
class QuestionAndAnswer(BaseModel):
question: str
answer: str
class Chats(Repository):
def __init__(self, supabase_client):
class Chats(ChatsInterface):
def __init__(self):
supabase_client = get_supabase_client()
self.db = supabase_client
def create_chat(self, new_chat):
@ -36,9 +21,7 @@ class Chats(Repository):
)
return response
def add_question_and_answer(
self, chat_id: UUID, question_and_answer: QuestionAndAnswer
) -> Optional[Chat]:
def add_question_and_answer(self, chat_id, question_and_answer):
response = (
self.db.table("chat_history")
.insert(
@ -66,7 +49,7 @@ class Chats(Repository):
return reponse
def get_user_chats(self, user_id: str):
def get_user_chats(self, user_id):
response = (
self.db.from_("chats")
.select("chat_id,user_id,creation_time,chat_name")
@ -76,7 +59,7 @@ class Chats(Repository):
)
return response
def update_chat_history(self, chat_history: CreateChatHistory):
def update_chat_history(self, chat_history):
response = (
self.db.table("chat_history")
.insert(
@ -114,15 +97,6 @@ class Chats(Repository):
return response
def get_chat_details(self, chat_id):
response = (
self.db.from_("chats")
.select("*")
.filter("chat_id", "eq", chat_id)
.execute()
)
return response
def delete_chat(self, chat_id):
self.db.table("chats").delete().match({"chat_id": chat_id}).execute()

View File

@ -0,0 +1,80 @@
from abc import ABC, abstractmethod
from typing import Optional
from uuid import UUID
from modules.chat.dto.inputs import CreateChatHistory, QuestionAndAnswer
from modules.chat.entity.chat import Chat
class ChatsInterface(ABC):
@abstractmethod
def create_chat(self, new_chat):
"""
Insert a chat entry in "chats" db
"""
pass
@abstractmethod
def get_chat_by_id(self, chat_id: str):
"""
Get chat details by chat_id
"""
pass
@abstractmethod
def add_question_and_answer(
self, chat_id: UUID, question_and_answer: QuestionAndAnswer
) -> Optional[Chat]:
"""
Add a question and answer to the chat history
"""
pass
@abstractmethod
def get_chat_history(self, chat_id: str):
"""
Get chat history by chat_id
"""
pass
@abstractmethod
def get_user_chats(self, user_id: str):
"""
Get all chats for a user
"""
pass
@abstractmethod
def update_chat_history(self, chat_history: CreateChatHistory):
"""
Update chat history
"""
pass
@abstractmethod
def update_chat(self, chat_id, updates):
"""
Update chat details
"""
pass
@abstractmethod
def update_message_by_id(self, message_id, updates):
"""
Update message details
"""
pass
@abstractmethod
def delete_chat(self, chat_id):
"""
Delete chat
"""
pass
@abstractmethod
def delete_chat_history(self, chat_id):
"""
Delete chat history
"""
pass

View File

@ -0,0 +1,169 @@
from typing import List, Optional
from uuid import UUID
from fastapi import HTTPException
from logger import get_logger
from modules.brain.service.brain_service import BrainService
from modules.chat.dto.chats import ChatItem
from modules.chat.dto.inputs import (
ChatUpdatableProperties,
CreateChatHistory,
CreateChatProperties,
QuestionAndAnswer,
)
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.entity.chat import Chat, ChatHistory
from modules.chat.repository.chats import Chats
from modules.chat.repository.chats_interface import ChatsInterface
from modules.chat.service.utils import merge_chat_history_and_notifications
from modules.notification.service.notification_service import NotificationService
from modules.prompt.service.prompt_service import PromptService
logger = get_logger(__name__)
prompt_service = PromptService()
brain_service = BrainService()
notification_service = NotificationService()
class ChatService:
repository: ChatsInterface
def __init__(self):
self.repository = Chats()
def create_chat(self, user_id: UUID, chat_data: CreateChatProperties) -> Chat:
# Chat is created upon the user's first question asked
logger.info(f"New chat entry in chats table for user {user_id}")
# Insert a new row into the chats table
new_chat = {
"user_id": str(user_id),
"chat_name": chat_data.name,
}
insert_response = self.repository.create_chat(new_chat)
logger.info(f"Insert response {insert_response.data}")
return insert_response.data[0]
def add_question_and_answer(
self, chat_id: UUID, question_and_answer: QuestionAndAnswer
) -> Optional[Chat]:
return self.repository.add_question_and_answer(chat_id, question_and_answer)
def get_chat_by_id(self, chat_id: str) -> Chat:
response = self.repository.get_chat_by_id(chat_id)
return Chat(response.data[0])
def get_chat_history(self, chat_id: str) -> List[GetChatHistoryOutput]:
history: List[dict] = self.repository.get_chat_history(chat_id).data
if history is None:
return []
else:
enriched_history: List[GetChatHistoryOutput] = []
for message in history:
message = ChatHistory(message)
brain = None
if message.brain_id:
brain = brain_service.get_brain_by_id(message.brain_id)
prompt = None
if message.prompt_id:
prompt = prompt_service.get_prompt_by_id(message.prompt_id)
enriched_history.append(
GetChatHistoryOutput(
chat_id=(UUID(message.chat_id)),
message_id=(UUID(message.message_id)),
user_message=message.user_message,
assistant=message.assistant,
message_time=message.message_time,
brain_name=brain.name if brain else None,
prompt_title=prompt.title if prompt else None,
)
)
return enriched_history
def get_chat_history_with_notifications(
self,
chat_id: UUID,
) -> List[ChatItem]:
chat_history = self.get_chat_history(str(chat_id))
chat_notifications = notification_service.get_chat_notifications(chat_id)
return merge_chat_history_and_notifications(chat_history, chat_notifications)
def get_user_chats(self, user_id: str) -> List[Chat]:
response = self.repository.get_user_chats(user_id)
chats = [Chat(chat_dict) for chat_dict in response.data]
return chats
def update_chat_history(self, chat_history: CreateChatHistory) -> ChatHistory:
response: List[ChatHistory] = (
self.repository.update_chat_history(chat_history)
).data
if len(response) == 0:
raise HTTPException(
status_code=500,
detail="An exception occurred while updating chat history.",
)
return ChatHistory(response[0]) # pyright: ignore reportPrivateUsage=none
def update_chat(self, chat_id, chat_data: ChatUpdatableProperties) -> Chat:
if not chat_id:
logger.error("No chat_id provided")
return # pyright: ignore reportPrivateUsage=none
updates = {}
if chat_data.chat_name is not None:
updates["chat_name"] = chat_data.chat_name
updated_chat = None
if updates:
updated_chat = (self.repository.update_chat(chat_id, updates)).data[0]
logger.info(f"Chat {chat_id} updated")
else:
logger.info(f"No updates to apply for chat {chat_id}")
return updated_chat # pyright: ignore reportPrivateUsage=none
def update_message_by_id(
self,
message_id: str,
user_message: str = None, # pyright: ignore reportPrivateUsage=none
assistant: str = None, # pyright: ignore reportPrivateUsage=none
) -> ChatHistory:
if not message_id:
logger.error("No message_id provided")
return # pyright: ignore reportPrivateUsage=none
updates = {}
if user_message is not None:
updates["user_message"] = user_message
if assistant is not None:
updates["assistant"] = assistant
updated_message = None
if updates:
updated_message = (self.repository.update_message_by_id(message_id, updates)).data[ # type: ignore
0
]
logger.info(f"Message {message_id} updated")
else:
logger.info(f"No updates to apply for message {message_id}")
return ChatHistory(updated_message) # pyright: ignore reportPrivateUsage=none
def delete_chat_from_db(self, chat_id):
try:
self.repository.delete_chat_history(chat_id)
except Exception as e:
print(e)
pass
try:
self.repository.delete_chat(chat_id)
except Exception as e:
print(e)
pass

View File

@ -1,28 +1,21 @@
from enum import Enum
from typing import List, Union
from uuid import UUID
from typing import List
from logger import get_logger
from modules.brain.service.brain_service import BrainService
from modules.chat.dto.chats import ChatItem, ChatItemType
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.notification.entity.notification import Notification
from modules.notification.service.notification_service import NotificationService
from modules.prompt.service.prompt_service import PromptService
from packages.utils import parse_message_time
from pydantic import BaseModel
from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
class ChatItemType(Enum):
MESSAGE = "MESSAGE"
NOTIFICATION = "NOTIFICATION"
class ChatItem(BaseModel):
item_type: ChatItemType
body: Union[GetChatHistoryOutput, Notification]
logger = get_logger(__name__)
prompt_service = PromptService()
brain_service = BrainService()
notification_service = NotificationService()
# Move these methods to ChatService in chat module
def merge_chat_history_and_notifications(
chat_history: List[GetChatHistoryOutput], notifications: List[Notification]
) -> List[ChatItem]:
@ -46,11 +39,3 @@ def merge_chat_history_and_notifications(
transformed_data.append(transformed_item)
return transformed_data
def get_chat_history_with_notifications(
chat_id: UUID,
) -> List[ChatItem]:
chat_history = get_chat_history(str(chat_id))
chat_notifications = notification_service.get_chat_notifications(chat_id)
return merge_chat_history_and_notifications(chat_history, chat_notifications)

View File

@ -1,3 +0,0 @@
from modules.brain.service.brain_user_service import BrainUserService
brain_user_service = BrainUserService()

View File

@ -8,6 +8,7 @@ logger = get_logger(__name__)
def get_question_context_from_brain(brain_id: UUID, question: str) -> str:
# TODO: Move to AnswerGenerator service
supabase_client = get_supabase_client()
embeddings = get_embeddings()

View File

@ -1,9 +0,0 @@
from .create_chat import create_chat, CreateChatProperties
from .update_chat import update_chat, ChatUpdatableProperties
from .get_user_chats import get_user_chats
from .get_chat_by_id import get_chat_by_id
from .get_chat_history import GetChatHistoryOutput, get_chat_history
from .update_chat_history import update_chat_history
from .update_message_by_id import update_message_by_id
from .format_chat_history import format_history_to_openai_mesages
from .format_chat_history import format_chat_history, format_history_to_openai_mesages

View File

@ -1,13 +0,0 @@
from typing import Optional
from uuid import UUID
from models import Chat, get_supabase_db
from models.databases.supabase.chats import QuestionAndAnswer
def add_question_and_answer(
chat_id: UUID, question_and_answer: QuestionAndAnswer
) -> Optional[Chat]:
supabase_db = get_supabase_db()
return supabase_db.add_question_and_answer(chat_id, question_and_answer)

View File

@ -1,32 +0,0 @@
from dataclasses import dataclass
from uuid import UUID
from logger import get_logger
from models import Chat, get_supabase_db
logger = get_logger(__name__)
@dataclass
class CreateChatProperties:
name: str
def __init__(self, name: str):
self.name = name
def create_chat(user_id: UUID, chat_data: CreateChatProperties) -> Chat:
supabase_db = get_supabase_db()
# Chat is created upon the user's first question asked
logger.info(f"New chat entry in chats table for user {user_id}")
# Insert a new row into the chats table
new_chat = {
"user_id": str(user_id),
"chat_name": chat_data.name,
}
insert_response = supabase_db.create_chat(new_chat)
logger.info(f"Insert response {insert_response.data}")
return insert_response.data[0]

View File

@ -1,8 +0,0 @@
from models import Chat, get_supabase_db
def get_chat_by_id(chat_id: str) -> Chat:
supabase_db = get_supabase_db()
response = supabase_db.get_chat_by_id(chat_id)
return Chat(response.data[0])

View File

@ -1,59 +0,0 @@
from typing import List, Optional
from uuid import UUID
from models import ChatHistory, get_supabase_db
from modules.brain.service.brain_service import BrainService
from modules.prompt.service.prompt_service import PromptService
from pydantic import BaseModel
prompt_service = PromptService()
brain_service = BrainService()
class GetChatHistoryOutput(BaseModel):
chat_id: UUID
message_id: UUID
user_message: str
assistant: str
message_time: str
prompt_title: Optional[str] | None
brain_name: Optional[str] | None
def dict(self, *args, **kwargs):
chat_history = super().dict(*args, **kwargs)
chat_history["chat_id"] = str(chat_history.get("chat_id"))
chat_history["message_id"] = str(chat_history.get("message_id"))
return chat_history
def get_chat_history(chat_id: str) -> List[GetChatHistoryOutput]:
supabase_db = get_supabase_db()
history: List[dict] = supabase_db.get_chat_history(chat_id).data
if history is None:
return []
else:
enriched_history: List[GetChatHistoryOutput] = []
for message in history:
message = ChatHistory(message)
brain = None
if message.brain_id:
brain = brain_service.get_brain_by_id(message.brain_id)
prompt = None
if message.prompt_id:
prompt = prompt_service.get_prompt_by_id(message.prompt_id)
enriched_history.append(
GetChatHistoryOutput(
chat_id=(UUID(message.chat_id)),
message_id=(UUID(message.message_id)),
user_message=message.user_message,
assistant=message.assistant,
message_time=message.message_time,
brain_name=brain.name if brain else None,
prompt_title=prompt.title if prompt else None,
)
)
return enriched_history

View File

@ -1,10 +0,0 @@
from typing import List
from models import Chat, get_supabase_db
def get_user_chats(user_id: str) -> List[Chat]:
supabase_db = get_supabase_db()
response = supabase_db.get_user_chats(user_id)
chats = [Chat(chat_dict) for chat_dict in response.data]
return chats

View File

@ -1,37 +0,0 @@
from dataclasses import dataclass
from typing import Optional
from logger import get_logger
from models import Chat, get_supabase_db
logger = get_logger(__name__)
@dataclass
class ChatUpdatableProperties:
chat_name: Optional[str] = None
def __init__(self, chat_name: Optional[str]):
self.chat_name = chat_name
def update_chat(chat_id, chat_data: ChatUpdatableProperties) -> Chat:
supabase_db = get_supabase_db()
if not chat_id:
logger.error("No chat_id provided")
return # pyright: ignore reportPrivateUsage=none
updates = {}
if chat_data.chat_name is not None:
updates["chat_name"] = chat_data.chat_name
updated_chat = None
if updates:
updated_chat = (supabase_db.update_chat(chat_id, updates)).data[0]
logger.info(f"Chat {chat_id} updated")
else:
logger.info(f"No updates to apply for chat {chat_id}")
return updated_chat # pyright: ignore reportPrivateUsage=none

View File

@ -1,15 +0,0 @@
from typing import List
from fastapi import HTTPException
from models.databases.supabase.chats import CreateChatHistory
from models import ChatHistory, get_supabase_db
def update_chat_history(chat_history: CreateChatHistory) -> ChatHistory:
supabase_db = get_supabase_db()
response: List[ChatHistory] = (supabase_db.update_chat_history(chat_history)).data
if len(response) == 0:
raise HTTPException(
status_code=500, detail="An exception occurred while updating chat history."
)
return ChatHistory(response[0]) # pyright: ignore reportPrivateUsage=none

View File

@ -1,35 +0,0 @@
from logger import get_logger
from models import ChatHistory, get_supabase_db
logger = get_logger(__name__)
def update_message_by_id(
message_id: str,
user_message: str = None, # pyright: ignore reportPrivateUsage=none
assistant: str = None, # pyright: ignore reportPrivateUsage=none
) -> ChatHistory:
supabase_db = get_supabase_db()
if not message_id:
logger.error("No message_id provided")
return # pyright: ignore reportPrivateUsage=none
updates = {}
if user_message is not None:
updates["user_message"] = user_message
if assistant is not None:
updates["assistant"] = assistant
updated_message = None
if updates:
updated_message = (supabase_db.update_message_by_id(message_id, updates)).data[ # type: ignore
0
]
logger.info(f"Message {message_id} updated")
else:
logger.info(f"No updates to apply for message {message_id}")
return ChatHistory(updated_message) # pyright: ignore reportPrivateUsage=none

View File

@ -1,2 +0,0 @@
from .create_secret import create_secret
from .read_secret import read_secret

View File

@ -1,21 +0,0 @@
from uuid import UUID
from models import get_supabase_client
from repository.external_api_secret.utils import build_secret_unique_name
def create_secret(
user_id: UUID, brain_id: UUID, secret_name: str, secret_value
) -> UUID | None:
supabase_client = get_supabase_client()
response = supabase_client.rpc(
"insert_secret",
{
"name": build_secret_unique_name(
user_id=user_id, brain_id=brain_id, secret_name=secret_name
),
"secret": secret_value,
},
).execute()
return response.data

View File

@ -1,23 +0,0 @@
from uuid import UUID
from models import get_supabase_client
from repository.external_api_secret.utils import build_secret_unique_name
def read_secret(
user_id: UUID,
brain_id: UUID,
secret_name: str,
) -> UUID | None:
supabase_client = get_supabase_client()
response = supabase_client.rpc(
"read_secret",
{
"secret_name": build_secret_unique_name(
user_id=user_id, brain_id=brain_id, secret_name=secret_name
),
},
).execute()
return response.data

View File

@ -1,26 +0,0 @@
from uuid import UUID
from modules.brain.service.brain_service import BrainService
from repository.external_api_secret.create_secret import create_secret
brain_service = BrainService()
def update_secret_value(
user_id: UUID,
brain_id: UUID,
secret_name: str,
secret_value: str,
) -> None:
"""Update an existing secret."""
brain_service.delete_secret(
user_id=user_id,
brain_id=brain_id,
secret_name=secret_name,
)
create_secret(
user_id=user_id,
brain_id=brain_id,
secret_name=secret_name,
secret_value=secret_value,
)

View File

@ -1,5 +0,0 @@
from uuid import UUID
def build_secret_unique_name(user_id: UUID, brain_id: UUID, secret_name: str):
return f"{user_id}-{brain_id}-{secret_name}"

View File

@ -17,7 +17,6 @@ from modules.brain.service.brain_user_service import BrainUserService
from modules.prompt.service.prompt_service import PromptService
from modules.user.entity.user_identity import UserIdentity
from repository.brain import get_question_context_from_brain
from repository.external_api_secret.update_secret_value import update_secret_value
logger = get_logger(__name__)
brain_router = APIRouter()
@ -197,7 +196,7 @@ async def update_existing_brain_secrets(
detail=f"Secret {key} is not a valid secret.",
)
if value:
update_secret_value(
brain_service.update_secret_value(
user_id=current_user.id,
brain_id=brain_id,
secret_name=key,
@ -226,6 +225,7 @@ async def set_brain_as_default(
tags=["Brain"],
)
async def get_question_context_for_brain(brain_id: UUID, request: BrainQuestionRequest):
# TODO: Move this endpoint to AnswerGenerator service
"""Retrieve the question context from a specific brain."""
context = get_question_context_from_brain(brain_id, request.question)
return {"context": context}

View File

@ -23,7 +23,6 @@ from repository.brain_subscription import (
SubscriptionInvitationService,
resend_invitation_email,
)
from repository.external_api_secret.create_secret import create_secret
from routes.headers.get_origin_header import get_origin_header
subscription_router = APIRouter()
@ -420,7 +419,7 @@ async def subscribe_to_brain_handler(
)
for secret in brain_secrets:
create_secret(
brain_service.external_api_secrets_repository.create_secret(
user_id=current_user.id,
brain_id=brain_id,
secret_name=secret.name,