mirror of
https://github.com/StanGirard/quivr.git
synced 2024-11-26 03:15:19 +03:00
refactor: chat for multibrains (#1812)
# Description - Chat Module - External Api Secrets Interface, exposed through brain service
This commit is contained in:
parent
8ddb7708fd
commit
436e49a5e7
@ -12,8 +12,8 @@ from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from logger import get_logger
|
||||
from middlewares.cors import add_cors_middleware
|
||||
from modules.chat.controller import chat_router
|
||||
from modules.misc.controller import misc_router
|
||||
from routes.chat_routes import chat_router
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -3,24 +3,25 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from logger import get_logger
|
||||
from litellm import completion
|
||||
from llm.qa_base import QABaseBrainPicking
|
||||
from llm.utils.call_brain_api import call_brain_api
|
||||
from llm.utils.get_api_brain_definition_as_json_schema import (
|
||||
get_api_brain_definition_as_json_schema,
|
||||
)
|
||||
from models.chats import ChatQuestion
|
||||
from models.databases.supabase.chats import CreateChatHistory
|
||||
from logger import get_logger
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
|
||||
from repository.chat.update_chat_history import update_chat_history
|
||||
from repository.chat.update_message_by_id import update_message_by_id
|
||||
from modules.chat.dto.chats import ChatQuestion
|
||||
from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
|
||||
brain_service = BrainService()
|
||||
chat_service = ChatService()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class APIBrainQA(
|
||||
QABaseBrainPicking,
|
||||
):
|
||||
@ -54,7 +55,7 @@ class APIBrainQA(
|
||||
messages,
|
||||
functions,
|
||||
brain_id: UUID,
|
||||
recursive_count = 0,
|
||||
recursive_count=0,
|
||||
):
|
||||
if recursive_count > 5:
|
||||
yield "🧠<Deciding what to do>🧠"
|
||||
@ -80,13 +81,16 @@ class APIBrainQA(
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
if finish_reason == "stop":
|
||||
break
|
||||
if "function_call" in chunk.choices[0].delta and chunk.choices[0].delta["function_call"]:
|
||||
if (
|
||||
"function_call" in chunk.choices[0].delta
|
||||
and chunk.choices[0].delta["function_call"]
|
||||
):
|
||||
if chunk.choices[0].delta["function_call"].name:
|
||||
function_call["name"] = chunk.choices[0].delta["function_call"].name
|
||||
if chunk.choices[0].delta["function_call"].arguments:
|
||||
function_call["arguments"] += chunk.choices[0].delta[
|
||||
"function_call"
|
||||
].arguments
|
||||
function_call["arguments"] += (
|
||||
chunk.choices[0].delta["function_call"].arguments
|
||||
)
|
||||
|
||||
elif finish_reason == "function_call":
|
||||
try:
|
||||
@ -154,7 +158,7 @@ class APIBrainQA(
|
||||
|
||||
messages = [{"role": "system", "content": prompt_content}]
|
||||
|
||||
history = get_chat_history(self.chat_id)
|
||||
history = chat_service.get_chat_history(self.chat_id)
|
||||
|
||||
for message in history:
|
||||
formatted_message = [
|
||||
@ -165,7 +169,7 @@ class APIBrainQA(
|
||||
|
||||
messages.append({"role": "user", "content": question.question})
|
||||
|
||||
streamed_chat_history = update_chat_history(
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
@ -203,7 +207,7 @@ class APIBrainQA(
|
||||
for token in response_tokens
|
||||
if not token.startswith("🧠<") and not token.endswith(">🧠")
|
||||
]
|
||||
update_message_by_id(
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant="".join(response_tokens),
|
||||
|
@ -15,21 +15,17 @@ from langchain.prompts.chat import (
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from llm.utils.format_chat_history import format_chat_history
|
||||
from llm.utils.get_prompt_to_use import get_prompt_to_use
|
||||
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
from logger import get_logger
|
||||
from models import BrainSettings # Importing settings related to the 'brain'
|
||||
from models.chats import ChatQuestion
|
||||
from models.databases.supabase.chats import CreateChatHistory
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.dto.chats import ChatQuestion
|
||||
from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
from pydantic import BaseModel
|
||||
from repository.chat import (
|
||||
GetChatHistoryOutput,
|
||||
format_chat_history,
|
||||
get_chat_history,
|
||||
update_chat_history,
|
||||
update_message_by_id,
|
||||
)
|
||||
from supabase.client import Client, create_client
|
||||
from vectorstore.supabase import CustomSupabaseVectorStore
|
||||
|
||||
@ -40,6 +36,7 @@ QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you
|
||||
|
||||
|
||||
brain_service = BrainService()
|
||||
chat_service = ChatService()
|
||||
|
||||
|
||||
class QABaseBrainPicking(BaseModel):
|
||||
@ -155,7 +152,6 @@ class QABaseBrainPicking(BaseModel):
|
||||
if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
|
||||
api_base = self.brain_settings.ollama_api_base_url
|
||||
|
||||
|
||||
return ChatLiteLLM(
|
||||
temperature=temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
@ -163,7 +159,7 @@ class QABaseBrainPicking(BaseModel):
|
||||
streaming=streaming,
|
||||
verbose=False,
|
||||
callbacks=callbacks,
|
||||
api_base= api_base
|
||||
api_base=api_base,
|
||||
) # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
def _create_prompt_template(self):
|
||||
@ -192,7 +188,9 @@ class QABaseBrainPicking(BaseModel):
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion
|
||||
) -> GetChatHistoryOutput:
|
||||
transformed_history = format_chat_history(get_chat_history(self.chat_id))
|
||||
transformed_history = format_chat_history(
|
||||
chat_service.get_chat_history(self.chat_id)
|
||||
)
|
||||
answering_llm = self._create_llm(
|
||||
model=self.model,
|
||||
streaming=False,
|
||||
@ -230,7 +228,7 @@ class QABaseBrainPicking(BaseModel):
|
||||
|
||||
answer = model_response["answer"]
|
||||
|
||||
new_chat = update_chat_history(
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
@ -264,7 +262,7 @@ class QABaseBrainPicking(BaseModel):
|
||||
async def generate_stream(
|
||||
self, chat_id: UUID, question: ChatQuestion
|
||||
) -> AsyncIterable:
|
||||
history = get_chat_history(self.chat_id)
|
||||
history = chat_service.get_chat_history(self.chat_id)
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
self.callbacks = [callback]
|
||||
|
||||
@ -323,7 +321,7 @@ class QABaseBrainPicking(BaseModel):
|
||||
if question.brain_id:
|
||||
brain = brain_service.get_brain_by_id(question.brain_id)
|
||||
|
||||
streamed_chat_history = update_chat_history(
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
@ -386,7 +384,7 @@ class QABaseBrainPicking(BaseModel):
|
||||
assistant += sources_string
|
||||
|
||||
try:
|
||||
update_message_by_id(
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant=assistant,
|
||||
|
@ -7,27 +7,25 @@ from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chat_models import ChatLiteLLM
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from models import BrainSettings # Importing settings related to the 'brain'
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from logger import get_logger
|
||||
from models.chats import ChatQuestion
|
||||
from models.databases.supabase.chats import CreateChatHistory
|
||||
from modules.prompt.entity.prompt import Prompt
|
||||
from pydantic import BaseModel
|
||||
from repository.chat import (
|
||||
GetChatHistoryOutput,
|
||||
from llm.utils.format_chat_history import (
|
||||
format_chat_history,
|
||||
format_history_to_openai_mesages,
|
||||
get_chat_history,
|
||||
update_chat_history,
|
||||
update_message_by_id,
|
||||
)
|
||||
|
||||
from llm.utils.get_prompt_to_use import get_prompt_to_use
|
||||
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
from logger import get_logger
|
||||
from models import BrainSettings # Importing settings related to the 'brain'
|
||||
from modules.chat.dto.chats import ChatQuestion
|
||||
from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
from modules.prompt.entity.prompt import Prompt
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer.When answering use markdown or any other techniques to display the content in a nice and aerated way."
|
||||
chat_service = ChatService()
|
||||
|
||||
|
||||
class HeadlessQA(BaseModel):
|
||||
@ -104,7 +102,10 @@ class HeadlessQA(BaseModel):
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion
|
||||
) -> GetChatHistoryOutput:
|
||||
transformed_history = format_chat_history(get_chat_history(self.chat_id))
|
||||
# Move format_chat_history to chat service ?
|
||||
transformed_history = format_chat_history(
|
||||
chat_service.get_chat_history(self.chat_id)
|
||||
)
|
||||
prompt_content = (
|
||||
self.prompt_to_use.content if self.prompt_to_use else SYSTEM_MESSAGE
|
||||
)
|
||||
@ -120,7 +121,7 @@ class HeadlessQA(BaseModel):
|
||||
model_prediction = answering_llm.predict_messages(messages)
|
||||
answer = model_prediction.content
|
||||
|
||||
new_chat = update_chat_history(
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
@ -152,7 +153,9 @@ class HeadlessQA(BaseModel):
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
self.callbacks = [callback]
|
||||
|
||||
transformed_history = format_chat_history(get_chat_history(self.chat_id))
|
||||
transformed_history = format_chat_history(
|
||||
chat_service.get_chat_history(self.chat_id)
|
||||
)
|
||||
prompt_content = (
|
||||
self.prompt_to_use.content if self.prompt_to_use else SYSTEM_MESSAGE
|
||||
)
|
||||
@ -186,7 +189,7 @@ class HeadlessQA(BaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
streamed_chat_history = update_chat_history(
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
@ -221,7 +224,7 @@ class HeadlessQA(BaseModel):
|
||||
await run
|
||||
assistant = "".join(response_tokens)
|
||||
|
||||
update_message_by_id(
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant=assistant,
|
||||
|
@ -5,10 +5,12 @@ from llm.utils.extract_api_brain_definition_values_from_llm_output import (
|
||||
extract_api_brain_definition_values_from_llm_output,
|
||||
)
|
||||
from llm.utils.make_api_request import get_api_call_response_as_text
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from repository.api_brain_definition.get_api_brain_definition import (
|
||||
get_api_brain_definition,
|
||||
)
|
||||
from repository.external_api_secret.read_secret import read_secret
|
||||
|
||||
brain_service = BrainService()
|
||||
|
||||
|
||||
def call_brain_api(brain_id: UUID, user_id: UUID, arguments: dict) -> str:
|
||||
@ -31,7 +33,7 @@ def call_brain_api(brain_id: UUID, user_id: UUID, arguments: dict) -> str:
|
||||
secrets_values = {}
|
||||
|
||||
for secret in secrets:
|
||||
secret_value = read_secret(
|
||||
secret_value = brain_service.external_api_secrets_repository.read_secret(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret.name
|
||||
)
|
||||
secrets_values[secret.name] = secret_value
|
||||
|
@ -1,6 +1,4 @@
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# import needed here when running main.py to debug backend
|
||||
@ -14,6 +12,7 @@ from fastapi.responses import JSONResponse
|
||||
from logger import get_logger
|
||||
from middlewares.cors import add_cors_middleware
|
||||
from modules.api_key.controller import api_key_router
|
||||
from modules.chat.controller import chat_router
|
||||
from modules.contact_support.controller import contact_router
|
||||
from modules.knowledge.controller import knowledge_router
|
||||
from modules.misc.controller import misc_router
|
||||
@ -22,12 +21,10 @@ from modules.onboarding.controller import onboarding_router
|
||||
from modules.prompt.controller import prompt_router
|
||||
from modules.upload.controller import upload_router
|
||||
from modules.user.controller import user_router
|
||||
from packages.utils import handle_request_validation_error
|
||||
from routes.brain_routes import brain_router
|
||||
from routes.chat_routes import chat_router
|
||||
from routes.crawl_routes import crawl_router
|
||||
from routes.subscription_routes import subscription_router
|
||||
from logger import get_logger
|
||||
from packages.utils import handle_request_validation_error
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
|
||||
|
@ -1,18 +1,6 @@
|
||||
from .brains_subscription_invitations import BrainSubscription
|
||||
from .chat import Chat, ChatHistory
|
||||
from .chats import ChatMessage, ChatQuestion
|
||||
from .files import File
|
||||
from .settings import (BrainRateLimiting, BrainSettings, ResendSettings,
|
||||
get_documents_vector_store, get_embeddings,
|
||||
get_supabase_client, get_supabase_db)
|
||||
from .user_usage import UserUsage
|
||||
|
||||
# TODO uncomment the below import when start using SQLalchemy
|
||||
# from .sqlalchemy_repository import (
|
||||
# User,
|
||||
# Brain,
|
||||
# BrainUser,
|
||||
# BrainVector,
|
||||
# BrainSubscriptionInvitation,
|
||||
# ApiKey
|
||||
# )
|
||||
|
@ -54,50 +54,6 @@ class Repository(ABC):
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_chat(self, new_chat):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_chat_by_id(self, chat_id: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_chat_history(self, chat_id: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_user_chats(self, user_id: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_chat_history(self, chat_id: str, user_message: str, assistant: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_chat(self, chat_id: UUID, updates):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_question_and_answer(self, chat_id: str, question_and_answer):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_message_by_id(self, message_id: UUID, updates):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_chat_details(self, chat_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_chat(self, chat_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_chat_history(self, chat_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_vectors_by_file_name(self, file_name: str):
|
||||
pass
|
||||
|
@ -1,7 +1,6 @@
|
||||
from models.databases.supabase.api_brain_definition import ApiBrainDefinitions
|
||||
from models.databases.supabase.brains_subscription_invitations import \
|
||||
BrainSubscription
|
||||
from models.databases.supabase.chats import Chats
|
||||
from models.databases.supabase.files import File
|
||||
from models.databases.supabase.user_usage import UserUsage
|
||||
from models.databases.supabase.vectors import Vector
|
||||
|
@ -2,7 +2,6 @@ from logger import get_logger
|
||||
from models.databases.supabase import (
|
||||
ApiBrainDefinitions,
|
||||
BrainSubscription,
|
||||
Chats,
|
||||
File,
|
||||
UserUsage,
|
||||
Vector,
|
||||
@ -15,7 +14,6 @@ class SupabaseDB(
|
||||
UserUsage,
|
||||
File,
|
||||
BrainSubscription,
|
||||
Chats,
|
||||
Vector,
|
||||
ApiBrainDefinitions,
|
||||
):
|
||||
@ -24,6 +22,5 @@ class SupabaseDB(
|
||||
UserUsage.__init__(self, supabase_client)
|
||||
File.__init__(self, supabase_client)
|
||||
BrainSubscription.__init__(self, supabase_client)
|
||||
Chats.__init__(self, supabase_client)
|
||||
Vector.__init__(self, supabase_client)
|
||||
ApiBrainDefinitions.__init__(self, supabase_client)
|
||||
|
@ -1,14 +1,14 @@
|
||||
from langchain.embeddings.ollama import OllamaEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from logger import get_logger
|
||||
from models.databases.supabase.supabase import SupabaseDB
|
||||
from pydantic import BaseSettings
|
||||
from supabase.client import Client, create_client
|
||||
from vectorstore.supabase import SupabaseVectorStore
|
||||
from langchain.embeddings.ollama import OllamaEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
from logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BrainRateLimiting(BaseSettings):
|
||||
max_brain_per_user: int = 5
|
||||
|
||||
|
@ -5,7 +5,6 @@ from models.settings import get_supabase_client
|
||||
from modules.brain.dto.inputs import BrainUpdatableProperties
|
||||
from modules.brain.entity.brain_entity import BrainEntity, PublicBrain
|
||||
from modules.brain.repository.interfaces.brains_interface import BrainsInterface
|
||||
from repository.external_api_secret.utils import build_secret_unique_name
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -97,15 +96,3 @@ class Brains(BrainsInterface):
|
||||
return None
|
||||
|
||||
return BrainEntity(**response[0])
|
||||
|
||||
def delete_secret(self, user_id: UUID, brain_id: UUID, secret_name: str) -> bool:
|
||||
response = self.db.rpc(
|
||||
"delete_secret",
|
||||
{
|
||||
"secret_name": build_secret_unique_name(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret_name
|
||||
),
|
||||
},
|
||||
).execute()
|
||||
|
||||
return response.data
|
||||
|
60
backend/modules/brain/repository/external_api_secrets.py
Normal file
60
backend/modules/brain/repository/external_api_secrets.py
Normal file
@ -0,0 +1,60 @@
|
||||
from uuid import UUID
|
||||
|
||||
from models.settings import get_supabase_client
|
||||
from modules.brain.repository.interfaces.external_api_secrets_interface import (
|
||||
ExternalApiSecretsInterface,
|
||||
)
|
||||
|
||||
|
||||
def build_secret_unique_name(user_id: UUID, brain_id: UUID, secret_name: str):
|
||||
return f"{user_id}-{brain_id}-{secret_name}"
|
||||
|
||||
|
||||
class ExternalApiSecrets(ExternalApiSecretsInterface):
|
||||
def __init__(self):
|
||||
supabase_client = get_supabase_client()
|
||||
self.db = supabase_client
|
||||
|
||||
def create_secret(
|
||||
self, user_id: UUID, brain_id: UUID, secret_name: str, secret_value
|
||||
) -> UUID | None:
|
||||
response = self.db.rpc(
|
||||
"insert_secret",
|
||||
{
|
||||
"name": build_secret_unique_name(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret_name
|
||||
),
|
||||
"secret": secret_value,
|
||||
},
|
||||
).execute()
|
||||
|
||||
return response.data
|
||||
|
||||
def read_secret(
|
||||
self,
|
||||
user_id: UUID,
|
||||
brain_id: UUID,
|
||||
secret_name: str,
|
||||
) -> UUID | None:
|
||||
response = self.db.rpc(
|
||||
"read_secret",
|
||||
{
|
||||
"secret_name": build_secret_unique_name(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret_name
|
||||
),
|
||||
},
|
||||
).execute()
|
||||
|
||||
return response.data
|
||||
|
||||
def delete_secret(self, user_id: UUID, brain_id: UUID, secret_name: str) -> bool:
|
||||
response = self.db.rpc(
|
||||
"delete_secret",
|
||||
{
|
||||
"secret_name": build_secret_unique_name(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret_name
|
||||
),
|
||||
},
|
||||
).execute()
|
||||
|
||||
return response.data
|
@ -56,10 +56,3 @@ class BrainsInterface(ABC):
|
||||
Get a brain by id
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_secret(self, user_id: UUID, brain_id: UUID, secret_name: str) -> bool:
|
||||
"""
|
||||
Delete a secret from a brain
|
||||
"""
|
||||
pass
|
||||
|
@ -1,41 +1,29 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
# TODO: Replace BrainsVectors with KnowledgeVectors interface instead
|
||||
class BrainsVectorsInterface(ABC):
|
||||
class ExternalApiSecretsInterface(ABC):
|
||||
@abstractmethod
|
||||
def create_brain_vector(self, brain_id, vector_id, file_sha1):
|
||||
def create_secret(
|
||||
self, user_id: UUID, brain_id: UUID, secret_name: str, secret_value
|
||||
) -> UUID | None:
|
||||
"""
|
||||
Create a brain vector
|
||||
Create a new secret for the API Request in given brain
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_vector_ids_from_file_sha1(self, file_sha1: str):
|
||||
def read_secret(
|
||||
self, user_id: UUID, brain_id: UUID, secret_name: str
|
||||
) -> UUID | None:
|
||||
"""
|
||||
Get vector ids from file sha1
|
||||
Read a secret for the API Request in given brain
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_brain_vector_ids(self, brain_id) -> List[UUID]:
|
||||
def delete_secret(self, user_id: UUID, brain_id: UUID, secret_name: str) -> bool:
|
||||
"""
|
||||
Get brain vector ids
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_file_from_brain(self, brain_id, file_name: str):
|
||||
"""
|
||||
Delete file from brain
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_brain_vector(self, brain_id: str):
|
||||
"""
|
||||
Delete brain vector
|
||||
Delete a secret from a brain
|
||||
"""
|
||||
pass
|
||||
|
@ -7,6 +7,7 @@ from modules.brain.entity.brain_entity import BrainEntity, BrainType, PublicBrai
|
||||
from modules.brain.repository.brains import Brains
|
||||
from modules.brain.repository.brains_users import BrainsUsers
|
||||
from modules.brain.repository.brains_vectors import BrainsVectors
|
||||
from modules.brain.repository.external_api_secrets import ExternalApiSecrets
|
||||
from modules.brain.repository.interfaces.brains_interface import BrainsInterface
|
||||
from modules.brain.repository.interfaces.brains_users_interface import (
|
||||
BrainsUsersInterface,
|
||||
@ -14,6 +15,9 @@ from modules.brain.repository.interfaces.brains_users_interface import (
|
||||
from modules.brain.repository.interfaces.brains_vectors_interface import (
|
||||
BrainsVectorsInterface,
|
||||
)
|
||||
from modules.brain.repository.interfaces.external_api_secrets_interface import (
|
||||
ExternalApiSecretsInterface,
|
||||
)
|
||||
from modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from repository.api_brain_definition.add_api_brain_definition import (
|
||||
add_api_brain_definition,
|
||||
@ -27,7 +31,6 @@ from repository.api_brain_definition.get_api_brain_definition import (
|
||||
from repository.api_brain_definition.update_api_brain_definition import (
|
||||
update_api_brain_definition,
|
||||
)
|
||||
from repository.external_api_secret.create_secret import create_secret
|
||||
|
||||
knowledge_service = KnowledgeService()
|
||||
|
||||
@ -36,11 +39,13 @@ class BrainService:
|
||||
brain_repository: BrainsInterface
|
||||
brain_user_repository: BrainsUsersInterface
|
||||
brain_vector_repository: BrainsVectorsInterface
|
||||
external_api_secrets_repository: ExternalApiSecretsInterface
|
||||
|
||||
def __init__(self):
|
||||
self.brain_repository = Brains()
|
||||
self.brain_user_repository = BrainsUsers()
|
||||
self.brain_vector = BrainsVectors()
|
||||
self.external_api_secrets_repository = ExternalApiSecrets()
|
||||
|
||||
def get_brain_by_id(self, brain_id: UUID):
|
||||
return self.brain_repository.get_brain_by_id(brain_id)
|
||||
@ -75,7 +80,7 @@ class BrainService:
|
||||
secrets_values = brain.brain_secrets_values
|
||||
|
||||
for secret_name in secrets_values:
|
||||
create_secret(
|
||||
self.external_api_secrets_repository.create_secret(
|
||||
user_id=user_id,
|
||||
brain_id=created_brain.brain_id,
|
||||
secret_name=secret_name,
|
||||
@ -96,7 +101,7 @@ class BrainService:
|
||||
brain_users = self.brain_user_repository.get_brain_users(brain_id=brain_id)
|
||||
for user in brain_users:
|
||||
for secret in secrets:
|
||||
self.brain_repository.delete_secret(
|
||||
self.external_api_secrets_repository.delete_secret(
|
||||
user_id=user.user_id,
|
||||
brain_id=brain_id,
|
||||
secret_name=secret.name,
|
||||
@ -217,12 +222,12 @@ class BrainService:
|
||||
secret_value: str,
|
||||
) -> None:
|
||||
"""Update an existing secret."""
|
||||
self.brain_repository.delete_secret(
|
||||
self.external_api_secrets_repository.delete_secret(
|
||||
user_id=user_id,
|
||||
brain_id=brain_id,
|
||||
secret_name=secret_name,
|
||||
)
|
||||
create_secret(
|
||||
self.external_api_secrets_repository.create_secret(
|
||||
user_id=user_id,
|
||||
brain_id=brain_id,
|
||||
secret_name=secret_name,
|
||||
|
1
backend/modules/chat/controller/__init__.py
Normal file
1
backend/modules/chat/controller/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .chat_routes import chat_router
|
0
backend/modules/chat/controller/chat/__init_.py
Normal file
0
backend/modules/chat/controller/chat/__init_.py
Normal file
@ -6,7 +6,7 @@ from modules.brain.service.brain_authorization_service import (
|
||||
validate_brain_authorization,
|
||||
)
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from routes.chat.interface import ChatInterface
|
||||
from modules.chat.controller.chat.interface import ChatInterface
|
||||
|
||||
models_supporting_function_calls = [
|
||||
"gpt-4",
|
@ -1,5 +1,5 @@
|
||||
from llm.qa_headless import HeadlessQA
|
||||
from routes.chat.interface import ChatInterface
|
||||
from modules.chat.controller.chat.interface import ChatInterface
|
||||
|
||||
|
||||
class BrainlessChat(ChatInterface):
|
@ -3,7 +3,6 @@ from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from models import UserUsage
|
||||
from models.databases.supabase.supabase import SupabaseDB
|
||||
from modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
|
||||
@ -22,19 +21,6 @@ class NullableUUID(UUID):
|
||||
return None
|
||||
|
||||
|
||||
def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
|
||||
try:
|
||||
supabase_db.delete_chat_history(chat_id)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
try:
|
||||
supabase_db.delete_chat(chat_id)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
|
||||
def check_user_requests_limit(
|
||||
user: UserIdentity,
|
||||
):
|
@ -7,36 +7,27 @@ from fastapi.responses import StreamingResponse
|
||||
from llm.qa_base import QABaseBrainPicking
|
||||
from llm.qa_headless import HeadlessQA
|
||||
from middlewares.auth import AuthBearer, get_current_user
|
||||
from models import Chat, ChatQuestion, UserUsage, get_supabase_db
|
||||
from models.databases.supabase.chats import QuestionAndAnswer
|
||||
from models.user_usage import UserUsage
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.notification.service.notification_service import NotificationService
|
||||
from modules.user.entity.user_identity import UserIdentity
|
||||
from repository.chat import (
|
||||
from modules.chat.controller.chat.factory import get_chat_strategy
|
||||
from modules.chat.controller.chat.utils import NullableUUID, check_user_requests_limit
|
||||
from modules.chat.dto.chats import ChatItem, ChatQuestion
|
||||
from modules.chat.dto.inputs import (
|
||||
ChatUpdatableProperties,
|
||||
CreateChatProperties,
|
||||
GetChatHistoryOutput,
|
||||
create_chat,
|
||||
get_chat_by_id,
|
||||
get_user_chats,
|
||||
update_chat,
|
||||
)
|
||||
from repository.chat.add_question_and_answer import add_question_and_answer
|
||||
from repository.chat.get_chat_history_with_notifications import (
|
||||
ChatItem,
|
||||
get_chat_history_with_notifications,
|
||||
)
|
||||
from routes.chat.factory import get_chat_strategy
|
||||
from routes.chat.utils import (
|
||||
NullableUUID,
|
||||
check_user_requests_limit,
|
||||
delete_chat_from_db,
|
||||
QuestionAndAnswer,
|
||||
)
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.entity.chat import Chat
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
from modules.notification.service.notification_service import NotificationService
|
||||
from modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
chat_router = APIRouter()
|
||||
|
||||
notification_service = NotificationService()
|
||||
brain_service = BrainService()
|
||||
chat_service = ChatService()
|
||||
|
||||
|
||||
@chat_router.get("/chat/healthz", tags=["Health"])
|
||||
@ -56,7 +47,7 @@ async def get_chats(current_user: UserIdentity = Depends(get_current_user)):
|
||||
This endpoint retrieves all the chats associated with the current authenticated user. It returns a list of chat objects
|
||||
containing the chat ID and chat name for each chat.
|
||||
"""
|
||||
chats = get_user_chats(str(current_user.id))
|
||||
chats = chat_service.get_user_chats(str(current_user.id))
|
||||
return {"chats": chats}
|
||||
|
||||
|
||||
@ -68,10 +59,9 @@ async def delete_chat(chat_id: UUID):
|
||||
"""
|
||||
Delete a specific chat by chat ID.
|
||||
"""
|
||||
supabase_db = get_supabase_db()
|
||||
notification_service.remove_chat_notifications(chat_id)
|
||||
|
||||
delete_chat_from_db(supabase_db=supabase_db, chat_id=chat_id)
|
||||
chat_service.delete_chat_from_db(chat_id)
|
||||
return {"message": f"{chat_id} has been deleted."}
|
||||
|
||||
|
||||
@ -83,18 +73,20 @@ async def update_chat_metadata_handler(
|
||||
chat_data: ChatUpdatableProperties,
|
||||
chat_id: UUID,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
) -> Chat:
|
||||
):
|
||||
"""
|
||||
Update chat attributes
|
||||
"""
|
||||
|
||||
chat = get_chat_by_id(chat_id) # pyright: ignore reportPrivateUsage=none
|
||||
chat = chat_service.get_chat_by_id(
|
||||
chat_id # pyright: ignore reportPrivateUsage=none
|
||||
)
|
||||
if str(current_user.id) != chat.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, # pyright: ignore reportPrivateUsage=none
|
||||
detail="You should be the owner of the chat to update it.", # pyright: ignore reportPrivateUsage=none
|
||||
)
|
||||
return update_chat(chat_id=chat_id, chat_data=chat_data)
|
||||
return chat_service.update_chat(chat_id=chat_id, chat_data=chat_data)
|
||||
|
||||
|
||||
# create new chat
|
||||
@ -107,7 +99,7 @@ async def create_chat_handler(
|
||||
Create a new chat with initial chat messages.
|
||||
"""
|
||||
|
||||
return create_chat(user_id=current_user.id, chat_data=chat_data)
|
||||
return chat_service.create_chat(user_id=current_user.id, chat_data=chat_data)
|
||||
|
||||
|
||||
# add new question to chat
|
||||
@ -271,7 +263,7 @@ async def get_chat_history_handler(
|
||||
chat_id: UUID,
|
||||
) -> List[ChatItem]:
|
||||
# TODO: RBAC with current_user
|
||||
return get_chat_history_with_notifications(chat_id)
|
||||
return chat_service.get_chat_history_with_notifications(chat_id)
|
||||
|
||||
|
||||
@chat_router.post(
|
||||
@ -286,4 +278,4 @@ async def add_question_and_answer_handler(
|
||||
"""
|
||||
Add a new question and anwser to the chat.
|
||||
"""
|
||||
return add_question_and_answer(chat_id, question_and_answer)
|
||||
return chat_service.add_question_and_answer(chat_id, question_and_answer)
|
@ -1,6 +1,9 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.notification.entity.notification import Notification
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -23,3 +26,13 @@ class ChatQuestion(BaseModel):
|
||||
max_tokens: Optional[int]
|
||||
brain_id: Optional[UUID]
|
||||
prompt_id: Optional[UUID]
|
||||
|
||||
|
||||
class ChatItemType(Enum):
|
||||
MESSAGE = "MESSAGE"
|
||||
NOTIFICATION = "NOTIFICATION"
|
||||
|
||||
|
||||
class ChatItem(BaseModel):
|
||||
item_type: ChatItemType
|
||||
body: Union[GetChatHistoryOutput, Notification]
|
34
backend/modules/chat/dto/inputs.py
Normal file
34
backend/modules/chat/dto/inputs.py
Normal file
@ -0,0 +1,34 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CreateChatHistory(BaseModel):
|
||||
chat_id: UUID
|
||||
user_message: str
|
||||
assistant: str
|
||||
prompt_id: Optional[UUID]
|
||||
brain_id: Optional[UUID]
|
||||
|
||||
|
||||
class QuestionAndAnswer(BaseModel):
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateChatProperties:
|
||||
name: str
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatUpdatableProperties:
|
||||
chat_name: Optional[str] = None
|
||||
|
||||
def __init__(self, chat_name: Optional[str]):
|
||||
self.chat_name = chat_name
|
21
backend/modules/chat/dto/outputs.py
Normal file
21
backend/modules/chat/dto/outputs.py
Normal file
@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GetChatHistoryOutput(BaseModel):
|
||||
chat_id: UUID
|
||||
message_id: UUID
|
||||
user_message: str
|
||||
assistant: str
|
||||
message_time: str
|
||||
prompt_title: Optional[str] | None
|
||||
brain_name: Optional[str] | None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
chat_history = super().dict(*args, **kwargs)
|
||||
chat_history["chat_id"] = str(chat_history.get("chat_id"))
|
||||
chat_history["message_id"] = str(chat_history.get("message_id"))
|
||||
|
||||
return chat_history
|
@ -1,26 +1,11 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from models.chat import Chat
|
||||
from models.databases.repository import Repository
|
||||
from pydantic import BaseModel
|
||||
from models.settings import get_supabase_client
|
||||
from modules.chat.entity.chat import Chat
|
||||
from modules.chat.repository.chats_interface import ChatsInterface
|
||||
|
||||
|
||||
class CreateChatHistory(BaseModel):
|
||||
chat_id: UUID
|
||||
user_message: str
|
||||
assistant: str
|
||||
prompt_id: Optional[UUID]
|
||||
brain_id: Optional[UUID]
|
||||
|
||||
|
||||
class QuestionAndAnswer(BaseModel):
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
class Chats(Repository):
|
||||
def __init__(self, supabase_client):
|
||||
class Chats(ChatsInterface):
|
||||
def __init__(self):
|
||||
supabase_client = get_supabase_client()
|
||||
self.db = supabase_client
|
||||
|
||||
def create_chat(self, new_chat):
|
||||
@ -36,9 +21,7 @@ class Chats(Repository):
|
||||
)
|
||||
return response
|
||||
|
||||
def add_question_and_answer(
|
||||
self, chat_id: UUID, question_and_answer: QuestionAndAnswer
|
||||
) -> Optional[Chat]:
|
||||
def add_question_and_answer(self, chat_id, question_and_answer):
|
||||
response = (
|
||||
self.db.table("chat_history")
|
||||
.insert(
|
||||
@ -66,7 +49,7 @@ class Chats(Repository):
|
||||
|
||||
return reponse
|
||||
|
||||
def get_user_chats(self, user_id: str):
|
||||
def get_user_chats(self, user_id):
|
||||
response = (
|
||||
self.db.from_("chats")
|
||||
.select("chat_id,user_id,creation_time,chat_name")
|
||||
@ -76,7 +59,7 @@ class Chats(Repository):
|
||||
)
|
||||
return response
|
||||
|
||||
def update_chat_history(self, chat_history: CreateChatHistory):
|
||||
def update_chat_history(self, chat_history):
|
||||
response = (
|
||||
self.db.table("chat_history")
|
||||
.insert(
|
||||
@ -114,15 +97,6 @@ class Chats(Repository):
|
||||
|
||||
return response
|
||||
|
||||
def get_chat_details(self, chat_id):
|
||||
response = (
|
||||
self.db.from_("chats")
|
||||
.select("*")
|
||||
.filter("chat_id", "eq", chat_id)
|
||||
.execute()
|
||||
)
|
||||
return response
|
||||
|
||||
def delete_chat(self, chat_id):
|
||||
self.db.table("chats").delete().match({"chat_id": chat_id}).execute()
|
||||
|
80
backend/modules/chat/repository/chats_interface.py
Normal file
80
backend/modules/chat/repository/chats_interface.py
Normal file
@ -0,0 +1,80 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from modules.chat.dto.inputs import CreateChatHistory, QuestionAndAnswer
|
||||
from modules.chat.entity.chat import Chat
|
||||
|
||||
|
||||
class ChatsInterface(ABC):
|
||||
@abstractmethod
|
||||
def create_chat(self, new_chat):
|
||||
"""
|
||||
Insert a chat entry in "chats" db
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_chat_by_id(self, chat_id: str):
|
||||
"""
|
||||
Get chat details by chat_id
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_question_and_answer(
|
||||
self, chat_id: UUID, question_and_answer: QuestionAndAnswer
|
||||
) -> Optional[Chat]:
|
||||
"""
|
||||
Add a question and answer to the chat history
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_chat_history(self, chat_id: str):
|
||||
"""
|
||||
Get chat history by chat_id
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_user_chats(self, user_id: str):
|
||||
"""
|
||||
Get all chats for a user
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_chat_history(self, chat_history: CreateChatHistory):
|
||||
"""
|
||||
Update chat history
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_chat(self, chat_id, updates):
|
||||
"""
|
||||
Update chat details
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_message_by_id(self, message_id, updates):
|
||||
"""
|
||||
Update message details
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_chat(self, chat_id):
|
||||
"""
|
||||
Delete chat
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_chat_history(self, chat_id):
|
||||
"""
|
||||
Delete chat history
|
||||
"""
|
||||
pass
|
169
backend/modules/chat/service/chat_service.py
Normal file
169
backend/modules/chat/service/chat_service.py
Normal file
@ -0,0 +1,169 @@
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from logger import get_logger
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.dto.chats import ChatItem
|
||||
from modules.chat.dto.inputs import (
|
||||
ChatUpdatableProperties,
|
||||
CreateChatHistory,
|
||||
CreateChatProperties,
|
||||
QuestionAndAnswer,
|
||||
)
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.entity.chat import Chat, ChatHistory
|
||||
from modules.chat.repository.chats import Chats
|
||||
from modules.chat.repository.chats_interface import ChatsInterface
|
||||
from modules.chat.service.utils import merge_chat_history_and_notifications
|
||||
from modules.notification.service.notification_service import NotificationService
|
||||
from modules.prompt.service.prompt_service import PromptService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
prompt_service = PromptService()
|
||||
brain_service = BrainService()
|
||||
notification_service = NotificationService()
|
||||
|
||||
|
||||
class ChatService:
|
||||
repository: ChatsInterface
|
||||
|
||||
def __init__(self):
|
||||
self.repository = Chats()
|
||||
|
||||
def create_chat(self, user_id: UUID, chat_data: CreateChatProperties) -> Chat:
|
||||
# Chat is created upon the user's first question asked
|
||||
logger.info(f"New chat entry in chats table for user {user_id}")
|
||||
|
||||
# Insert a new row into the chats table
|
||||
new_chat = {
|
||||
"user_id": str(user_id),
|
||||
"chat_name": chat_data.name,
|
||||
}
|
||||
insert_response = self.repository.create_chat(new_chat)
|
||||
logger.info(f"Insert response {insert_response.data}")
|
||||
|
||||
return insert_response.data[0]
|
||||
|
||||
def add_question_and_answer(
|
||||
self, chat_id: UUID, question_and_answer: QuestionAndAnswer
|
||||
) -> Optional[Chat]:
|
||||
return self.repository.add_question_and_answer(chat_id, question_and_answer)
|
||||
|
||||
def get_chat_by_id(self, chat_id: str) -> Chat:
|
||||
response = self.repository.get_chat_by_id(chat_id)
|
||||
return Chat(response.data[0])
|
||||
|
||||
def get_chat_history(self, chat_id: str) -> List[GetChatHistoryOutput]:
|
||||
history: List[dict] = self.repository.get_chat_history(chat_id).data
|
||||
if history is None:
|
||||
return []
|
||||
else:
|
||||
enriched_history: List[GetChatHistoryOutput] = []
|
||||
for message in history:
|
||||
message = ChatHistory(message)
|
||||
brain = None
|
||||
if message.brain_id:
|
||||
brain = brain_service.get_brain_by_id(message.brain_id)
|
||||
|
||||
prompt = None
|
||||
if message.prompt_id:
|
||||
prompt = prompt_service.get_prompt_by_id(message.prompt_id)
|
||||
|
||||
enriched_history.append(
|
||||
GetChatHistoryOutput(
|
||||
chat_id=(UUID(message.chat_id)),
|
||||
message_id=(UUID(message.message_id)),
|
||||
user_message=message.user_message,
|
||||
assistant=message.assistant,
|
||||
message_time=message.message_time,
|
||||
brain_name=brain.name if brain else None,
|
||||
prompt_title=prompt.title if prompt else None,
|
||||
)
|
||||
)
|
||||
return enriched_history
|
||||
|
||||
def get_chat_history_with_notifications(
|
||||
self,
|
||||
chat_id: UUID,
|
||||
) -> List[ChatItem]:
|
||||
chat_history = self.get_chat_history(str(chat_id))
|
||||
chat_notifications = notification_service.get_chat_notifications(chat_id)
|
||||
return merge_chat_history_and_notifications(chat_history, chat_notifications)
|
||||
|
||||
def get_user_chats(self, user_id: str) -> List[Chat]:
|
||||
response = self.repository.get_user_chats(user_id)
|
||||
chats = [Chat(chat_dict) for chat_dict in response.data]
|
||||
return chats
|
||||
|
||||
def update_chat_history(self, chat_history: CreateChatHistory) -> ChatHistory:
|
||||
response: List[ChatHistory] = (
|
||||
self.repository.update_chat_history(chat_history)
|
||||
).data
|
||||
if len(response) == 0:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An exception occurred while updating chat history.",
|
||||
)
|
||||
return ChatHistory(response[0]) # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
def update_chat(self, chat_id, chat_data: ChatUpdatableProperties) -> Chat:
|
||||
if not chat_id:
|
||||
logger.error("No chat_id provided")
|
||||
return # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
updates = {}
|
||||
|
||||
if chat_data.chat_name is not None:
|
||||
updates["chat_name"] = chat_data.chat_name
|
||||
|
||||
updated_chat = None
|
||||
|
||||
if updates:
|
||||
updated_chat = (self.repository.update_chat(chat_id, updates)).data[0]
|
||||
logger.info(f"Chat {chat_id} updated")
|
||||
else:
|
||||
logger.info(f"No updates to apply for chat {chat_id}")
|
||||
return updated_chat # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
def update_message_by_id(
|
||||
self,
|
||||
message_id: str,
|
||||
user_message: str = None, # pyright: ignore reportPrivateUsage=none
|
||||
assistant: str = None, # pyright: ignore reportPrivateUsage=none
|
||||
) -> ChatHistory:
|
||||
if not message_id:
|
||||
logger.error("No message_id provided")
|
||||
return # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
updates = {}
|
||||
|
||||
if user_message is not None:
|
||||
updates["user_message"] = user_message
|
||||
|
||||
if assistant is not None:
|
||||
updates["assistant"] = assistant
|
||||
|
||||
updated_message = None
|
||||
|
||||
if updates:
|
||||
updated_message = (self.repository.update_message_by_id(message_id, updates)).data[ # type: ignore
|
||||
0
|
||||
]
|
||||
logger.info(f"Message {message_id} updated")
|
||||
else:
|
||||
logger.info(f"No updates to apply for message {message_id}")
|
||||
return ChatHistory(updated_message) # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
def delete_chat_from_db(self, chat_id):
|
||||
try:
|
||||
self.repository.delete_chat_history(chat_id)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
try:
|
||||
self.repository.delete_chat(chat_id)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
@ -1,28 +1,21 @@
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
from uuid import UUID
|
||||
from typing import List
|
||||
|
||||
from logger import get_logger
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.dto.chats import ChatItem, ChatItemType
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.notification.entity.notification import Notification
|
||||
from modules.notification.service.notification_service import NotificationService
|
||||
from modules.prompt.service.prompt_service import PromptService
|
||||
from packages.utils import parse_message_time
|
||||
from pydantic import BaseModel
|
||||
from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
|
||||
|
||||
|
||||
class ChatItemType(Enum):
|
||||
MESSAGE = "MESSAGE"
|
||||
NOTIFICATION = "NOTIFICATION"
|
||||
|
||||
|
||||
class ChatItem(BaseModel):
|
||||
item_type: ChatItemType
|
||||
body: Union[GetChatHistoryOutput, Notification]
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
prompt_service = PromptService()
|
||||
brain_service = BrainService()
|
||||
notification_service = NotificationService()
|
||||
|
||||
|
||||
# Move these methods to ChatService in chat module
|
||||
def merge_chat_history_and_notifications(
|
||||
chat_history: List[GetChatHistoryOutput], notifications: List[Notification]
|
||||
) -> List[ChatItem]:
|
||||
@ -46,11 +39,3 @@ def merge_chat_history_and_notifications(
|
||||
transformed_data.append(transformed_item)
|
||||
|
||||
return transformed_data
|
||||
|
||||
|
||||
def get_chat_history_with_notifications(
|
||||
chat_id: UUID,
|
||||
) -> List[ChatItem]:
|
||||
chat_history = get_chat_history(str(chat_id))
|
||||
chat_notifications = notification_service.get_chat_notifications(chat_id)
|
||||
return merge_chat_history_and_notifications(chat_history, chat_notifications)
|
@ -1,3 +0,0 @@
|
||||
from modules.brain.service.brain_user_service import BrainUserService
|
||||
|
||||
brain_user_service = BrainUserService()
|
@ -8,6 +8,7 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_question_context_from_brain(brain_id: UUID, question: str) -> str:
|
||||
# TODO: Move to AnswerGenerator service
|
||||
supabase_client = get_supabase_client()
|
||||
embeddings = get_embeddings()
|
||||
|
||||
|
@ -1,9 +0,0 @@
|
||||
from .create_chat import create_chat, CreateChatProperties
|
||||
from .update_chat import update_chat, ChatUpdatableProperties
|
||||
from .get_user_chats import get_user_chats
|
||||
from .get_chat_by_id import get_chat_by_id
|
||||
from .get_chat_history import GetChatHistoryOutput, get_chat_history
|
||||
from .update_chat_history import update_chat_history
|
||||
from .update_message_by_id import update_message_by_id
|
||||
from .format_chat_history import format_history_to_openai_mesages
|
||||
from .format_chat_history import format_chat_history, format_history_to_openai_mesages
|
@ -1,13 +0,0 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from models import Chat, get_supabase_db
|
||||
from models.databases.supabase.chats import QuestionAndAnswer
|
||||
|
||||
|
||||
def add_question_and_answer(
|
||||
chat_id: UUID, question_and_answer: QuestionAndAnswer
|
||||
) -> Optional[Chat]:
|
||||
supabase_db = get_supabase_db()
|
||||
|
||||
return supabase_db.add_question_and_answer(chat_id, question_and_answer)
|
@ -1,32 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from logger import get_logger
|
||||
from models import Chat, get_supabase_db
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateChatProperties:
|
||||
name: str
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
def create_chat(user_id: UUID, chat_data: CreateChatProperties) -> Chat:
|
||||
supabase_db = get_supabase_db()
|
||||
|
||||
# Chat is created upon the user's first question asked
|
||||
logger.info(f"New chat entry in chats table for user {user_id}")
|
||||
|
||||
# Insert a new row into the chats table
|
||||
new_chat = {
|
||||
"user_id": str(user_id),
|
||||
"chat_name": chat_data.name,
|
||||
}
|
||||
insert_response = supabase_db.create_chat(new_chat)
|
||||
logger.info(f"Insert response {insert_response.data}")
|
||||
|
||||
return insert_response.data[0]
|
@ -1,8 +0,0 @@
|
||||
from models import Chat, get_supabase_db
|
||||
|
||||
|
||||
def get_chat_by_id(chat_id: str) -> Chat:
|
||||
supabase_db = get_supabase_db()
|
||||
|
||||
response = supabase_db.get_chat_by_id(chat_id)
|
||||
return Chat(response.data[0])
|
@ -1,59 +0,0 @@
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from models import ChatHistory, get_supabase_db
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.prompt.service.prompt_service import PromptService
|
||||
from pydantic import BaseModel
|
||||
|
||||
prompt_service = PromptService()
|
||||
|
||||
brain_service = BrainService()
|
||||
|
||||
|
||||
class GetChatHistoryOutput(BaseModel):
|
||||
chat_id: UUID
|
||||
message_id: UUID
|
||||
user_message: str
|
||||
assistant: str
|
||||
message_time: str
|
||||
prompt_title: Optional[str] | None
|
||||
brain_name: Optional[str] | None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
chat_history = super().dict(*args, **kwargs)
|
||||
chat_history["chat_id"] = str(chat_history.get("chat_id"))
|
||||
chat_history["message_id"] = str(chat_history.get("message_id"))
|
||||
|
||||
return chat_history
|
||||
|
||||
|
||||
def get_chat_history(chat_id: str) -> List[GetChatHistoryOutput]:
|
||||
supabase_db = get_supabase_db()
|
||||
history: List[dict] = supabase_db.get_chat_history(chat_id).data
|
||||
if history is None:
|
||||
return []
|
||||
else:
|
||||
enriched_history: List[GetChatHistoryOutput] = []
|
||||
for message in history:
|
||||
message = ChatHistory(message)
|
||||
brain = None
|
||||
if message.brain_id:
|
||||
brain = brain_service.get_brain_by_id(message.brain_id)
|
||||
|
||||
prompt = None
|
||||
if message.prompt_id:
|
||||
prompt = prompt_service.get_prompt_by_id(message.prompt_id)
|
||||
|
||||
enriched_history.append(
|
||||
GetChatHistoryOutput(
|
||||
chat_id=(UUID(message.chat_id)),
|
||||
message_id=(UUID(message.message_id)),
|
||||
user_message=message.user_message,
|
||||
assistant=message.assistant,
|
||||
message_time=message.message_time,
|
||||
brain_name=brain.name if brain else None,
|
||||
prompt_title=prompt.title if prompt else None,
|
||||
)
|
||||
)
|
||||
return enriched_history
|
@ -1,10 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from models import Chat, get_supabase_db
|
||||
|
||||
|
||||
def get_user_chats(user_id: str) -> List[Chat]:
|
||||
supabase_db = get_supabase_db()
|
||||
response = supabase_db.get_user_chats(user_id)
|
||||
chats = [Chat(chat_dict) for chat_dict in response.data]
|
||||
return chats
|
@ -1,37 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from logger import get_logger
|
||||
from models import Chat, get_supabase_db
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatUpdatableProperties:
|
||||
chat_name: Optional[str] = None
|
||||
|
||||
def __init__(self, chat_name: Optional[str]):
|
||||
self.chat_name = chat_name
|
||||
|
||||
|
||||
def update_chat(chat_id, chat_data: ChatUpdatableProperties) -> Chat:
|
||||
supabase_db = get_supabase_db()
|
||||
|
||||
if not chat_id:
|
||||
logger.error("No chat_id provided")
|
||||
return # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
updates = {}
|
||||
|
||||
if chat_data.chat_name is not None:
|
||||
updates["chat_name"] = chat_data.chat_name
|
||||
|
||||
updated_chat = None
|
||||
|
||||
if updates:
|
||||
updated_chat = (supabase_db.update_chat(chat_id, updates)).data[0]
|
||||
logger.info(f"Chat {chat_id} updated")
|
||||
else:
|
||||
logger.info(f"No updates to apply for chat {chat_id}")
|
||||
return updated_chat # pyright: ignore reportPrivateUsage=none
|
@ -1,15 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import HTTPException
|
||||
from models.databases.supabase.chats import CreateChatHistory
|
||||
from models import ChatHistory, get_supabase_db
|
||||
|
||||
|
||||
def update_chat_history(chat_history: CreateChatHistory) -> ChatHistory:
|
||||
supabase_db = get_supabase_db()
|
||||
response: List[ChatHistory] = (supabase_db.update_chat_history(chat_history)).data
|
||||
if len(response) == 0:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An exception occurred while updating chat history."
|
||||
)
|
||||
return ChatHistory(response[0]) # pyright: ignore reportPrivateUsage=none
|
@ -1,35 +0,0 @@
|
||||
from logger import get_logger
|
||||
from models import ChatHistory, get_supabase_db
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def update_message_by_id(
|
||||
message_id: str,
|
||||
user_message: str = None, # pyright: ignore reportPrivateUsage=none
|
||||
assistant: str = None, # pyright: ignore reportPrivateUsage=none
|
||||
) -> ChatHistory:
|
||||
supabase_db = get_supabase_db()
|
||||
|
||||
if not message_id:
|
||||
logger.error("No message_id provided")
|
||||
return # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
updates = {}
|
||||
|
||||
if user_message is not None:
|
||||
updates["user_message"] = user_message
|
||||
|
||||
if assistant is not None:
|
||||
updates["assistant"] = assistant
|
||||
|
||||
updated_message = None
|
||||
|
||||
if updates:
|
||||
updated_message = (supabase_db.update_message_by_id(message_id, updates)).data[ # type: ignore
|
||||
0
|
||||
]
|
||||
logger.info(f"Message {message_id} updated")
|
||||
else:
|
||||
logger.info(f"No updates to apply for message {message_id}")
|
||||
return ChatHistory(updated_message) # pyright: ignore reportPrivateUsage=none
|
@ -1,2 +0,0 @@
|
||||
from .create_secret import create_secret
|
||||
from .read_secret import read_secret
|
@ -1,21 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from models import get_supabase_client
|
||||
from repository.external_api_secret.utils import build_secret_unique_name
|
||||
|
||||
|
||||
def create_secret(
|
||||
user_id: UUID, brain_id: UUID, secret_name: str, secret_value
|
||||
) -> UUID | None:
|
||||
supabase_client = get_supabase_client()
|
||||
response = supabase_client.rpc(
|
||||
"insert_secret",
|
||||
{
|
||||
"name": build_secret_unique_name(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret_name
|
||||
),
|
||||
"secret": secret_value,
|
||||
},
|
||||
).execute()
|
||||
|
||||
return response.data
|
@ -1,23 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from models import get_supabase_client
|
||||
|
||||
from repository.external_api_secret.utils import build_secret_unique_name
|
||||
|
||||
|
||||
def read_secret(
|
||||
user_id: UUID,
|
||||
brain_id: UUID,
|
||||
secret_name: str,
|
||||
) -> UUID | None:
|
||||
supabase_client = get_supabase_client()
|
||||
response = supabase_client.rpc(
|
||||
"read_secret",
|
||||
{
|
||||
"secret_name": build_secret_unique_name(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret_name
|
||||
),
|
||||
},
|
||||
).execute()
|
||||
|
||||
return response.data
|
@ -1,26 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from repository.external_api_secret.create_secret import create_secret
|
||||
|
||||
brain_service = BrainService()
|
||||
|
||||
|
||||
def update_secret_value(
|
||||
user_id: UUID,
|
||||
brain_id: UUID,
|
||||
secret_name: str,
|
||||
secret_value: str,
|
||||
) -> None:
|
||||
"""Update an existing secret."""
|
||||
brain_service.delete_secret(
|
||||
user_id=user_id,
|
||||
brain_id=brain_id,
|
||||
secret_name=secret_name,
|
||||
)
|
||||
create_secret(
|
||||
user_id=user_id,
|
||||
brain_id=brain_id,
|
||||
secret_name=secret_name,
|
||||
secret_value=secret_value,
|
||||
)
|
@ -1,5 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def build_secret_unique_name(user_id: UUID, brain_id: UUID, secret_name: str):
|
||||
return f"{user_id}-{brain_id}-{secret_name}"
|
@ -17,7 +17,6 @@ from modules.brain.service.brain_user_service import BrainUserService
|
||||
from modules.prompt.service.prompt_service import PromptService
|
||||
from modules.user.entity.user_identity import UserIdentity
|
||||
from repository.brain import get_question_context_from_brain
|
||||
from repository.external_api_secret.update_secret_value import update_secret_value
|
||||
|
||||
logger = get_logger(__name__)
|
||||
brain_router = APIRouter()
|
||||
@ -197,7 +196,7 @@ async def update_existing_brain_secrets(
|
||||
detail=f"Secret {key} is not a valid secret.",
|
||||
)
|
||||
if value:
|
||||
update_secret_value(
|
||||
brain_service.update_secret_value(
|
||||
user_id=current_user.id,
|
||||
brain_id=brain_id,
|
||||
secret_name=key,
|
||||
@ -226,6 +225,7 @@ async def set_brain_as_default(
|
||||
tags=["Brain"],
|
||||
)
|
||||
async def get_question_context_for_brain(brain_id: UUID, request: BrainQuestionRequest):
|
||||
# TODO: Move this endpoint to AnswerGenerator service
|
||||
"""Retrieve the question context from a specific brain."""
|
||||
context = get_question_context_from_brain(brain_id, request.question)
|
||||
return {"context": context}
|
||||
|
@ -23,7 +23,6 @@ from repository.brain_subscription import (
|
||||
SubscriptionInvitationService,
|
||||
resend_invitation_email,
|
||||
)
|
||||
from repository.external_api_secret.create_secret import create_secret
|
||||
from routes.headers.get_origin_header import get_origin_header
|
||||
|
||||
subscription_router = APIRouter()
|
||||
@ -420,7 +419,7 @@ async def subscribe_to_brain_handler(
|
||||
)
|
||||
|
||||
for secret in brain_secrets:
|
||||
create_secret(
|
||||
brain_service.external_api_secrets_repository.create_secret(
|
||||
user_id=current_user.id,
|
||||
brain_id=brain_id,
|
||||
secret_name=secret.name,
|
||||
|
Loading…
Reference in New Issue
Block a user