diff --git a/.vscode/settings.json b/.vscode/settings.json index 60757ebb9..15ede991a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,7 +2,8 @@ "python.formatting.provider": "black", "editor.codeActionsOnSave": { "source.organizeImports": true, - "source.fixAll": true + "source.fixAll": true, + "source.unusedImports": true }, "python.linting.enabled": true, "python.linting.flake8Enabled": true, diff --git a/backend/llm/OpenAiFunctionBasedAnswerGenerator/OpenAiFunctionBasedAnswerGenerator.py b/backend/llm/OpenAiFunctionBasedAnswerGenerator/OpenAiFunctionBasedAnswerGenerator.py new file mode 100644 index 000000000..59033e5b2 --- /dev/null +++ b/backend/llm/OpenAiFunctionBasedAnswerGenerator/OpenAiFunctionBasedAnswerGenerator.py @@ -0,0 +1,245 @@ +from typing import Optional + + +from typing import Any, Dict, List # For type hinting +from langchain.chat_models import ChatOpenAI +from repository.chat.get_chat_history import get_chat_history +from .utils.format_answer import format_answer + + +# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing +from langchain.embeddings.openai import OpenAIEmbeddings + + +from models.settings import BrainSettings # Importing settings related to the 'brain' +from llm.OpenAiFunctionBasedAnswerGenerator.models.OpenAiAnswer import OpenAiAnswer + +from supabase import Client, create_client # For interacting with Supabase database +from vectorstore.supabase import ( + CustomSupabaseVectorStore, +) # Custom class for handling vector storage with Supabase +from logger import get_logger + +logger = get_logger(__name__) + +get_context_function_name = "get_context" + +prompt_template = """Your name is Quivr. You are a second brain. +A person will ask you a question and you will provide a helpful answer. +Write the answer in the same language as the question. +If you don't know the answer, just say that you don't know. Don't try to make up an answer. +Your main goal is to answer questions about user uploaded documents. 
Unless basic questions or greetings, you should always refer to user uploaded documents by fetching them with the {} function.""".format( + get_context_function_name +) + +get_history_schema = { + "name": "get_history", + "description": "Get current chat previous messages", + "parameters": { + "type": "object", + "properties": {}, + }, +} + +get_context_schema = { + "name": get_context_function_name, + "description": "A function which returns user uploaded documents and which must be used when you don't now the answer to a question or when the question seems to refer to user uploaded documents", + "parameters": { + "type": "object", + "properties": {}, + }, +} + + +class OpenAiFunctionBasedAnswerGenerator: + # Default class attributes + model: str = "gpt-3.5-turbo-0613" + temperature: float = 0.0 + max_tokens: int = 256 + chat_id: str + supabase_client: Client = None + embeddings: OpenAIEmbeddings = None + settings = BrainSettings() + openai_client: ChatOpenAI = None + user_email: str + + def __init__( + self, + model: str, + chat_id: str, + temperature: float, + max_tokens: int, + user_email: str, + user_openai_api_key: str, + ) -> "OpenAiFunctionBasedAnswerGenerator": + self.model = model + self.temperature = temperature + self.max_tokens = max_tokens + self.chat_id = chat_id + self.supabase_client = create_client( + self.settings.supabase_url, self.settings.supabase_service_key + ) + self.user_email = user_email + if user_openai_api_key is not None: + self.settings.openai_api_key = user_openai_api_key + self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key) + self.openai_client = ChatOpenAI(openai_api_key=self.settings.openai_api_key) + + def _get_model_response( + self, + messages: list[dict[str, str]] = [], + functions: list[dict[str, Any]] = None, + ): + if functions is not None: + model_response = self.openai_client.completion_with_retry( + functions=functions, + messages=messages, + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + else: + model_response = self.openai_client.completion_with_retry( + messages=messages, + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + return model_response + + def _get_formatted_history(self) -> List[Dict[str, str]]: + formatted_history = [] + history = get_chat_history(self.chat_id) + for chat in history: + formatted_history.append({"role": "user", "content": chat.user_message}) + formatted_history.append({"role": "assistant", "content": chat.assistant}) + return formatted_history + + def _get_formatted_prompt( + self, + question: Optional[str], + useContext: Optional[bool] = False, + useHistory: Optional[bool] = False, + ) -> list[dict[str, str]]: + messages = [ + {"role": "system", "content": prompt_template}, + ] + + if not useHistory and not useContext: + messages.append( + {"role": "user", "content": question}, + ) + return messages + + if useHistory: + history = self._get_formatted_history() + if len(history): + messages.append( + { + "role": "system", + "content": "Previous messages are already in chat.", + }, + ) + messages.extend(history) + else: + messages.append( + { + "role": "user", + "content": "This is the first message of the chat. 
There is no previous one", + } + ) + + messages.append( + { + "role": "user", + "content": f"Question: {question}\n\nAnswer:", + } + ) + if useContext: + chat_context = self._get_context(question) + enhanced_question = f"Here is chat context: {chat_context if len(chat_context) else 'No document found'}" + messages.append({"role": "user", "content": enhanced_question}) + messages.append( + { + "role": "user", + "content": f"Question: {question}\n\nAnswer:", + } + ) + return messages + + def _get_context(self, question: str) -> str: + # retrieve 5 nearest documents + vector_store = CustomSupabaseVectorStore( + self.supabase_client, + self.embeddings, + table_name="vectors", + user_id=self.user_email, + ) + + return vector_store.similarity_search( + query=question, + ) + + def _get_answer_from_question(self, question: str) -> OpenAiAnswer: + functions = [get_history_schema, get_context_schema] + + model_response = self._get_model_response( + messages=self._get_formatted_prompt(question=question), + functions=functions, + ) + + return format_answer(model_response) + + def _get_answer_from_question_and_history(self, question: str) -> OpenAiAnswer: + logger.info("Using chat history") + + functions = [ + get_context_schema, + ] + + model_response = self._get_model_response( + messages=self._get_formatted_prompt(question=question, useHistory=True), + functions=functions, + ) + + return format_answer(model_response) + + def _get_answer_from_question_and_context(self, question: str) -> OpenAiAnswer: + logger.info("Using documents ") + + functions = [ + get_history_schema, + ] + model_response = self._get_model_response( + messages=self._get_formatted_prompt(question=question, useContext=True), + functions=functions, + ) + + return format_answer(model_response) + + def _get_answer_from_question_and_context_and_history( + self, question: str + ) -> OpenAiAnswer: + logger.info("Using context and history") + model_response = self._get_model_response( + messages=self._get_formatted_prompt( + question, useContext=True, useHistory=True + ), + ) + return format_answer(model_response) + + def get_answer(self, question: str) -> str: + response = self._get_answer_from_question(question) + function_name = response.function_call.name if response.function_call else None + + if function_name == "get_history": + response = self._get_answer_from_question_and_history(question) + elif function_name == "get_context": + response = self._get_answer_from_question_and_context(question) + + if response.function_call: + response = self._get_answer_from_question_and_context_and_history(question) + + return response.content or "" diff --git a/backend/llm/OpenAiFunctionBasedAnswerGenerator/models/FunctionCall.py b/backend/llm/OpenAiFunctionBasedAnswerGenerator/models/FunctionCall.py new file mode 100644 index 000000000..36640c259 --- /dev/null +++ b/backend/llm/OpenAiFunctionBasedAnswerGenerator/models/FunctionCall.py @@ -0,0 +1,12 @@ +from typing import Optional +from typing import Any, Dict # For type hinting + + +class FunctionCall: + def __init__( + self, + name: Optional[str] = None, + arguments: Optional[Dict[str, Any]] = None, + ): + self.name = name + self.arguments = arguments diff --git a/backend/llm/OpenAiFunctionBasedAnswerGenerator/models/OpenAiAnswer.py b/backend/llm/OpenAiFunctionBasedAnswerGenerator/models/OpenAiAnswer.py new file mode 100644 index 000000000..dc09850f4 --- /dev/null +++ b/backend/llm/OpenAiFunctionBasedAnswerGenerator/models/OpenAiAnswer.py @@ -0,0 +1,12 @@ +from typing import Optional +from 
.FunctionCall import FunctionCall + + +class OpenAiAnswer: + def __init__( + self, + content: Optional[str] = None, + function_call: FunctionCall = None, + ): + self.content = content + self.function_call = function_call diff --git a/backend/llm/OpenAiFunctionBasedAnswerGenerator/utils/format_answer.py b/backend/llm/OpenAiFunctionBasedAnswerGenerator/utils/format_answer.py new file mode 100644 index 000000000..84ebf59e4 --- /dev/null +++ b/backend/llm/OpenAiFunctionBasedAnswerGenerator/utils/format_answer.py @@ -0,0 +1,17 @@ +from llm.OpenAiFunctionBasedAnswerGenerator.models.OpenAiAnswer import OpenAiAnswer +from llm.OpenAiFunctionBasedAnswerGenerator.models.FunctionCall import FunctionCall +from typing import Any, Dict # For type hinting + + +def format_answer(model_response: Dict[str, Any]) -> OpenAiAnswer: + answer = model_response["choices"][0]["message"] + content = answer["content"] + function_call = None + + if answer.get("function_call", None) is not None: + function_call = FunctionCall( + answer["function_call"]["name"], + answer["function_call"]["arguments"], + ) + + return OpenAiAnswer(content=content, function_call=function_call) diff --git a/backend/llm/brainpicking.py b/backend/llm/brainpicking.py index d7c1e6422..94e640e86 100644 --- a/backend/llm/brainpicking.py +++ b/backend/llm/brainpicking.py @@ -1,32 +1,21 @@ -import os # A module to interact with the OS -from typing import Any, Dict, List +from typing import Any, Dict from models.settings import LLMSettings # For type hinting # Importing various modules and classes from a custom library 'langchain' likely used for natural language processing from langchain.chains import ConversationalRetrievalChain, LLMChain from langchain.chains.question_answering import load_qa_chain -from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser -from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE -from langchain.chat_models import ChatOpenAI, ChatVertexAI -from langchain.chat_models.anthropic import ChatAnthropic -from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.chat_models import ChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.llms import GPT4All from langchain.llms.base import LLM from langchain.memory import ConversationBufferMemory -from langchain.vectorstores import SupabaseVectorStore -from llm.prompt import LANGUAGE_PROMPT + from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT -from models.chats import ( - ChatMessage, -) # Importing a custom ChatMessage class for handling chat messages from models.settings import BrainSettings # Importing settings related to the 'brain' from pydantic import BaseModel # For data validation and settings management -from pydantic import BaseSettings from supabase import Client # For interacting with Supabase database from supabase import create_client +from repository.chat.get_chat_history import get_chat_history from vectorstore.supabase import ( CustomSupabaseVectorStore, ) # Custom class for handling vector storage with Supabase @@ -34,6 +23,7 @@ from logger import get_logger logger = get_logger(__name__) + class AnswerConversationBufferMemory(ConversationBufferMemory): """ This class is a specialized version of ConversationBufferMemory. 
@@ -48,7 +38,7 @@ class AnswerConversationBufferMemory(ConversationBufferMemory): ) -def get_chat_history(inputs) -> str: +def format_chat_history(inputs) -> str: """ Function to concatenate chat history into a single string. :param inputs: List of tuples containing human and AI messages. @@ -76,18 +66,38 @@ class BrainPicking(BaseModel): llm: LLM = None question_generator: LLMChain = None doc_chain: ConversationalRetrievalChain = None + chat_id: str + max_tokens: int = 256 class Config: # Allowing arbitrary types for class validation arbitrary_types_allowed = True - def init(self, model: str, user_id: str) -> "BrainPicking": + def __init__( + self, + model: str, + user_id: str, + chat_id: str, + max_tokens: int, + user_openai_api_key: str, + ) -> "BrainPicking": """ Initialize the BrainPicking class by setting embeddings, supabase client, vector store, language model and chains. :param model: Language model name to be used. :param user_id: The user id to be used for CustomSupabaseVectorStore. :return: BrainPicking instance """ + super().__init__( + model=model, + user_id=user_id, + chat_id=chat_id, + max_tokens=max_tokens, + user_openai_api_key=user_openai_api_key, + ) + # If user provided an API key, update the settings + if user_openai_api_key is not None: + self.settings.openai_api_key = user_openai_api_key + self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key) self.supabase_client = create_client( self.settings.supabase_url, self.settings.supabase_service_key @@ -98,7 +108,7 @@ class BrainPicking(BaseModel): table_name="vectors", user_id=user_id, ) - + self.llm = self._determine_llm( private_model_args={ "model_path": self.llm_config.model_path, @@ -112,7 +122,8 @@ class BrainPicking(BaseModel): llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT ) self.doc_chain = load_qa_chain(self.llm, chain_type="stuff") - return self + self.chat_id = chat_id + self.max_tokens = max_tokens def _determine_llm( self, private_model_args: dict, private: bool = False, model_name: str = None @@ -128,7 +139,7 @@ class BrainPicking(BaseModel): model_path = private_model_args["model_path"] model_n_ctx = private_model_args["n_ctx"] model_n_batch = private_model_args["n_batch"] - + logger.info("Using private model: %s", model_path) return GPT4All( @@ -142,7 +153,7 @@ class BrainPicking(BaseModel): return ChatOpenAI(temperature=0, model_name=model_name) def _get_qa( - self, chat_message: ChatMessage, user_openai_api_key + self, ) -> ConversationalRetrievalChain: """ Retrieves a QA chain for the given chat message and API key. @@ -150,42 +161,34 @@ class BrainPicking(BaseModel): :param user_openai_api_key: The OpenAI API key to be used. :return: ConversationalRetrievalChain instance """ - # If user provided an API key, update the settings - if user_openai_api_key is not None and user_openai_api_key != "": - self.settings.openai_api_key = user_openai_api_key # Initialize and return a ConversationalRetrievalChain qa = ConversationalRetrievalChain( retriever=self.vector_store.as_retriever(), - max_tokens_limit=chat_message.max_tokens, + max_tokens_limit=self.max_tokens, question_generator=self.question_generator, combine_docs_chain=self.doc_chain, - get_chat_history=get_chat_history, + get_chat_history=format_chat_history, ) return qa - def generate_answer(self, chat_message: ChatMessage, user_openai_api_key) -> str: + def generate_answer(self, question: str) -> str: """ - Generate an answer to a given chat message by interacting with the language model. 
- :param chat_message: The chat message containing history. - :param user_openai_api_key: The OpenAI API key to be used. + Generate an answer to a given question by interacting with the language model. + :param question: The question :return: The generated answer. """ transformed_history = [] # Get the QA chain - qa = self._get_qa(chat_message, user_openai_api_key) + qa = self._get_qa() + history = get_chat_history(self.chat_id) - # Transform the chat history into a list of tuples - for i in range(0, len(chat_message.history) - 1, 2): - user_message = chat_message.history[i][1] - assistant_message = chat_message.history[i + 1][1] - transformed_history.append((user_message, assistant_message)) + # Format the chat history into a list of tuples (human, ai) + transformed_history = [(chat.user_message, chat.assistant) for chat in history] # Generate the model response using the QA chain - model_response = qa( - {"question": chat_message.question, "chat_history": transformed_history} - ) + model_response = qa({"question": question, "chat_history": transformed_history}) answer = model_response["answer"] return answer diff --git a/backend/llm/prompt/CONDENSE_PROMPT.py b/backend/llm/prompt/CONDENSE_PROMPT.py index c0c2bdc01..86f54da52 100644 --- a/backend/llm/prompt/CONDENSE_PROMPT.py +++ b/backend/llm/prompt/CONDENSE_PROMPT.py @@ -6,4 +6,4 @@ Chat History: {chat_history} Follow Up Input: {question} Standalone question:""" -CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template) \ No newline at end of file +CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template) diff --git a/backend/llm/prompt/LANGUAGE_PROMPT.py b/backend/llm/prompt/LANGUAGE_PROMPT.py index 0c13c3aa4..0139e594e 100644 --- a/backend/llm/prompt/LANGUAGE_PROMPT.py +++ b/backend/llm/prompt/LANGUAGE_PROMPT.py @@ -9,4 +9,4 @@ Question: {question} Helpful Answer:""" QA_PROMPT = PromptTemplate( template=prompt_template, input_variables=["context", "question"] - ) \ No newline at end of file +) diff --git a/backend/logger.py b/backend/logger.py index 0962a2d89..21a896de2 100644 --- a/backend/logger.py +++ b/backend/logger.py @@ -6,8 +6,7 @@ def get_logger(logger_name, log_level=logging.INFO): logger.setLevel(log_level) logger.propagate = False # Prevent log propagation to avoid double logging - formatter = logging.Formatter( - '%(asctime)s [%(levelname)s] %(name)s: %(message)s') + formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s") console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) diff --git a/backend/main.py b/backend/main.py index 6f5f53130..2034ba58e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,7 +1,7 @@ import os - +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse import pypandoc -from fastapi import FastAPI from logger import get_logger from middlewares.cors import add_cors_middleware from routes.api_key_routes import api_key_router @@ -37,3 +37,11 @@ app.include_router(upload_router) app.include_router(user_router) app.include_router(api_key_router) app.include_router(stream_router) + + +@app.exception_handler(HTTPException) +async def http_exception_handler(request, exc): + return JSONResponse( + status_code=exc.status_code, + content={"detail": exc.detail}, + ) diff --git a/backend/models/chat.py b/backend/models/chat.py new file mode 100644 index 000000000..9d8da514b --- /dev/null +++ b/backend/models/chat.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass + + +@dataclass +class Chat: + chat_id: 
str + user_id: str + creation_time: str + chat_name: str + + def __init__(self, chat_dict: dict): + self.chat_id = chat_dict.get("chat_id") + self.user_id = chat_dict.get("user_id") + self.creation_time = chat_dict.get("creation_time") + self.chat_name = chat_dict.get("chat_name") + + +@dataclass +class ChatHistory: + chat_id: str + message_id: str + user_message: str + assistant: str + message_time: str + + def __init__(self, chat_dict: dict): + self.chat_id = chat_dict.get("chat_id") + self.message_id = chat_dict.get("message_id") + self.user_message = chat_dict.get("user_message") + self.assistant = chat_dict.get("assistant") + self.message_time = chat_dict.get("message_time") diff --git a/backend/models/chats.py b/backend/models/chats.py index 502afc841..c89d2da65 100644 --- a/backend/models/chats.py +++ b/backend/models/chats.py @@ -16,5 +16,8 @@ class ChatMessage(BaseModel): chat_name: Optional[str] = None -class ChatAttributes(BaseModel): - chat_name: Optional[str] = None +class ChatQuestion(BaseModel): + model: str = "gpt-3.5-turbo-0613" + question: str + temperature: float = 0.0 + max_tokens: int = 256 diff --git a/backend/models/settings.py b/backend/models/settings.py index 8f9192614..b4774015a 100644 --- a/backend/models/settings.py +++ b/backend/models/settings.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, Dict, List, Tuple, Union +from typing import Annotated from fastapi import Depends from langchain.embeddings.openai import OpenAIEmbeddings @@ -13,26 +13,33 @@ class BrainSettings(BaseSettings): supabase_url: str supabase_service_key: str + class LLMSettings(BaseSettings): private: bool model_path: str model_n_ctx: int model_n_batch: int + def common_dependencies() -> dict: settings = BrainSettings() embeddings = OpenAIEmbeddings(openai_api_key=settings.openai_api_key) - supabase_client: Client = create_client(settings.supabase_url, settings.supabase_service_key) + supabase_client: Client = create_client( + settings.supabase_url, settings.supabase_service_key + ) documents_vector_store = SupabaseVectorStore( - supabase_client, embeddings, table_name="vectors") + supabase_client, embeddings, table_name="vectors" + ) summaries_vector_store = SupabaseVectorStore( - supabase_client, embeddings, table_name="summaries") - + supabase_client, embeddings, table_name="summaries" + ) + return { "supabase": supabase_client, "embeddings": embeddings, "documents_vector_store": documents_vector_store, - "summaries_vector_store": summaries_vector_store + "summaries_vector_store": summaries_vector_store, } -CommonsDep = Annotated[dict, Depends(common_dependencies)] \ No newline at end of file + +CommonsDep = Annotated[dict, Depends(common_dependencies)] diff --git a/backend/repository/chat/create_chat.py b/backend/repository/chat/create_chat.py new file mode 100644 index 000000000..b5c17f5e6 --- /dev/null +++ b/backend/repository/chat/create_chat.py @@ -0,0 +1,32 @@ +from logger import get_logger +from models.settings import common_dependencies +from dataclasses import dataclass +from models.chat import Chat + + +logger = get_logger(__name__) + + +@dataclass +class CreateChatProperties: + name: str + + def __init__(self, name: str): + self.name = name + + +def create_chat(user_id: str, chat_data: CreateChatProperties) -> Chat: + commons = common_dependencies() + + # Chat is created upon the user's first question asked + logger.info(f"New chat entry in chats table for user {user_id}") + + # Insert a new row into the chats table + new_chat = { + "user_id": user_id, + "chat_name": 
chat_data.name, + } + insert_response = commons["supabase"].table("chats").insert(new_chat).execute() + logger.info(f"Insert response {insert_response.data}") + + return insert_response.data[0] diff --git a/backend/repository/chat/get_chat_by_id.py b/backend/repository/chat/get_chat_by_id.py new file mode 100644 index 000000000..5cebfa16c --- /dev/null +++ b/backend/repository/chat/get_chat_by_id.py @@ -0,0 +1,15 @@ +from models.chat import Chat +from models.settings import common_dependencies + + +def get_chat_by_id(chat_id: str) -> Chat: + commons = common_dependencies() + + response = ( + commons["supabase"] + .from_("chats") + .select("*") + .filter("chat_id", "eq", chat_id) + .execute() + ) + return Chat(response.data[0]) diff --git a/backend/repository/chat/get_chat_history.py b/backend/repository/chat/get_chat_history.py new file mode 100644 index 000000000..4de83b4e5 --- /dev/null +++ b/backend/repository/chat/get_chat_history.py @@ -0,0 +1,19 @@ +from models.chat import ChatHistory +from models.settings import common_dependencies +from typing import List # For type hinting + + +def get_chat_history(chat_id: str) -> List[ChatHistory]: + commons = common_dependencies() + history: List[ChatHistory] = ( + commons["supabase"] + .from_("chat_history") + .select("*") + .filter("chat_id", "eq", chat_id) + .order("message_time", desc=False) # Add the ORDER BY clause + .execute() + ).data + if history is None: + return [] + else: + return [ChatHistory(message) for message in history] diff --git a/backend/repository/chat/get_user_chats.py b/backend/repository/chat/get_user_chats.py new file mode 100644 index 000000000..107c81edf --- /dev/null +++ b/backend/repository/chat/get_user_chats.py @@ -0,0 +1,15 @@ +from models.settings import common_dependencies +from models.chat import Chat + + +def get_user_chats(user_id: str) -> list[Chat]: + commons = common_dependencies() + response = ( + commons["supabase"] + .from_("chats") + .select("chat_id,user_id,creation_time,chat_name") + .filter("user_id", "eq", user_id) + .execute() + ) + chats = [Chat(chat_dict) for chat_dict in response.data] + return chats diff --git a/backend/repository/chat/update_chat.py b/backend/repository/chat/update_chat.py new file mode 100644 index 000000000..fea0cf55b --- /dev/null +++ b/backend/repository/chat/update_chat.py @@ -0,0 +1,45 @@ +from logger import get_logger +from models.chat import Chat +from typing import Optional +from dataclasses import dataclass + +from models.settings import common_dependencies + + +logger = get_logger(__name__) + + +@dataclass +class ChatUpdatableProperties: + chat_name: Optional[str] = None + + def __init__(self, chat_name: Optional[str]): + self.chat_name = chat_name + + +def update_chat(chat_id, chat_data: ChatUpdatableProperties) -> Chat: + commons = common_dependencies() + + if not chat_id: + logger.error("No chat_id provided") + return + + updates = {} + + if chat_data.chat_name is not None: + updates["chat_name"] = chat_data.chat_name + + updated_chat = None + + if updates: + updated_chat = ( + commons["supabase"] + .table("chats") + .update(updates) + .match({"chat_id": chat_id}) + .execute() + ).data[0] + logger.info(f"Chat {chat_id} updated") + else: + logger.info(f"No updates to apply for chat {chat_id}") + return updated_chat diff --git a/backend/repository/chat/update_chat_history.py b/backend/repository/chat/update_chat_history.py new file mode 100644 index 000000000..b42cf1e6f --- /dev/null +++ b/backend/repository/chat/update_chat_history.py @@ -0,0 +1,24 @@ +from 
models.chat import ChatHistory +from models.settings import common_dependencies +from typing import List # For type hinting + + +def update_chat_history( + chat_id: str, user_message: str, assistant_answer: str +) -> ChatHistory: + commons = common_dependencies() + response: List[ChatHistory] = ( + commons["supabase"] + .table("chat_history") + .insert( + { + "chat_id": str(chat_id), + "user_message": user_message, + "assistant": assistant_answer, + } + ) + .execute() + ).data + if len(response) == 0: + raise Exception("Error while updating chat history") + return response[0] diff --git a/backend/requirements.txt b/backend/requirements.txt index 4a9817bdb..89a14a051 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,5 +1,5 @@ pymupdf==1.22.3 -langchain==0.0.200 +langchain==0.0.207 Markdown==3.4.3 openai==0.27.6 pdf2image==1.16.3 diff --git a/backend/routes/brain_routes.py b/backend/routes/brain_routes.py index cb24f60c1..b6838b62e 100644 --- a/backend/routes/brain_routes.py +++ b/backend/routes/brain_routes.py @@ -1,5 +1,3 @@ -import os -import time from typing import Optional from uuid import UUID @@ -7,7 +5,7 @@ from auth.auth_bearer import AuthBearer, get_current_user from fastapi import APIRouter, Depends, Request from logger import get_logger from models.brains import Brain -from models.settings import CommonsDep, common_dependencies +from models.settings import common_dependencies from models.users import User from pydantic import BaseModel from utils.users import fetch_user_id_from_credentials @@ -50,7 +48,7 @@ async def brain_endpoint(current_user: User = Depends(get_current_user)): @brain_router.get( "/brains/{brain_id}", dependencies=[Depends(AuthBearer())], tags=["Brain"] ) -async def brain_endpoint(brain_id: UUID): +async def get_brain_endpoint(brain_id: UUID): """ Retrieve details of a specific brain by brain ID. @@ -76,7 +74,7 @@ async def brain_endpoint(brain_id: UUID): @brain_router.delete( "/brains/{brain_id}", dependencies=[Depends(AuthBearer())], tags=["Brain"] ) -async def brain_endpoint(brain_id: UUID): +async def delete_brain_endpoint(brain_id: UUID): """ Delete a specific brain by brain ID. 
""" @@ -97,7 +95,7 @@ class BrainObject(BaseModel): # create new brain @brain_router.post("/brains", dependencies=[Depends(AuthBearer())], tags=["Brain"]) -async def brain_endpoint( +async def create_brain_endpoint( request: Request, brain: BrainObject, current_user: User = Depends(get_current_user), @@ -125,7 +123,7 @@ async def brain_endpoint( @brain_router.put( "/brains/{brain_id}", dependencies=[Depends(AuthBearer())], tags=["Brain"] ) -async def brain_endpoint( +async def update_brain_endpoint( request: Request, brain_id: UUID, input_brain: Brain, diff --git a/backend/routes/chat_routes.py b/backend/routes/chat_routes.py index 19a8f7865..f5c3e67ae 100644 --- a/backend/routes/chat_routes.py +++ b/backend/routes/chat_routes.py @@ -1,32 +1,37 @@ import os import time from uuid import UUID +from models.chat import ChatHistory from auth.auth_bearer import AuthBearer, get_current_user from fastapi import APIRouter, Depends, Request from llm.brainpicking import BrainPicking -from models.chats import ChatMessage, ChatAttributes -from models.settings import CommonsDep, common_dependencies +from fastapi import HTTPException +from models.chat import Chat +from models.chats import ChatQuestion +from models.settings import common_dependencies from models.users import User -from utils.chats import (create_chat, get_chat_name_from_first_question, - update_chat) -from utils.users import (create_user, fetch_user_id_from_credentials, - update_user_request_count) -from http.client import HTTPException +from repository.chat.get_chat_history import get_chat_history +from repository.chat.update_chat import update_chat, ChatUpdatableProperties +from repository.chat.create_chat import create_chat, CreateChatProperties +from utils.users import ( + fetch_user_id_from_credentials, + update_user_request_count, +) + + +from repository.chat.get_chat_by_id import get_chat_by_id + +from repository.chat.get_user_chats import get_user_chats +from repository.chat.update_chat_history import update_chat_history + + +from llm.OpenAiFunctionBasedAnswerGenerator.OpenAiFunctionBasedAnswerGenerator import ( + OpenAiFunctionBasedAnswerGenerator, +) chat_router = APIRouter() -def get_user_chats(commons, user_id): - response = ( - commons["supabase"] - .from_("chats") - .select("chatId:chat_id, chatName:chat_name") - .filter("user_id", "eq", user_id) - .execute() - ) - return response.data - - def get_chat_details(commons, chat_id): response = ( commons["supabase"] @@ -69,32 +74,14 @@ async def get_chats(current_user: User = Depends(get_current_user)): """ commons = common_dependencies() user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email}) - chats = get_user_chats(commons, user_id) + chats = get_user_chats(user_id) return {"chats": chats} -# get one chat -@chat_router.get("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"]) -async def get_chat_handler(chat_id: UUID): - """ - Retrieve details of a specific chat by chat ID. - - - `chat_id`: The ID of the chat to retrieve details for. - - Returns the chat ID and its history. - - This endpoint retrieves the details of a specific chat identified by the provided chat ID. It returns the chat ID and its - history, which includes the chat messages exchanged in the chat. 
- """ - commons = common_dependencies() - chats = get_chat_details(commons, chat_id) - if len(chats) > 0: - return {"chatId": chat_id, "history": chats[0]["history"]} - else: - return {"error": "Chat not found"} - - # delete one chat -@chat_router.delete("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"]) +@chat_router.delete( + "/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"] +) async def delete_chat(chat_id: UUID): """ Delete a specific chat by chat ID. @@ -104,91 +91,118 @@ async def delete_chat(chat_id: UUID): return {"message": f"{chat_id} has been deleted."} -# helper method for update and create chat -def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=False): - date = time.strftime("%Y%m%d") - user_id = fetch_user_id_from_credentials(commons, {"email": email}) - max_requests_number = os.getenv("MAX_REQUESTS_NUMBER") - user_openai_api_key = request.headers.get("Openai-Api-Key") - - userItem = fetch_user_stats(commons, User(email=email), date) - old_request_count = userItem["requests_count"] - - history = chat_message.history - history.append(("user", chat_message.question)) - - chat_name = chat_message.chat_name - - if old_request_count == 0: - create_user(commons, email=email, date=date) - else: - update_user_request_count( - commons, email, date, requests_count=old_request_count + 1 - ) - if user_openai_api_key is None and old_request_count >= float(max_requests_number): - history.append(("assistant", "You have reached your requests limit")) - update_chat(commons, chat_id=chat_id, history=history, chat_name=chat_name) - return {"history": history} - - brainPicking = BrainPicking().init(chat_message.model, email) - answer = brainPicking.generate_answer(chat_message, user_openai_api_key) - history.append(("assistant", answer)) - - if is_new_chat: - chat_name = get_chat_name_from_first_question(chat_message) - new_chat = create_chat(commons, user_id, history, chat_name) - chat_id = new_chat.data[0]["chat_id"] - else: - update_chat(commons, chat_id=chat_id, history=history, chat_name=chat_name) - - return {"history": history, "chatId": chat_id} - - -# update existing chat -@chat_router.put("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"]) -async def chat_endpoint( - request: Request, - commons: CommonsDep, - chat_id: UUID, - chat_message: ChatMessage, - current_user: User = Depends(get_current_user), -): - """ - Update an existing chat with new chat messages. 
- """ - return chat_handler(request, commons, chat_id, chat_message, current_user.email) - - -# update existing chat -@chat_router.put("/chat/{chat_id}/metadata", dependencies=[Depends(AuthBearer())], tags=["Chat"]) -async def update_chat_attributes_handler( - commons: CommonsDep, - chat_message: ChatAttributes, +# update existing chat metadata +@chat_router.put( + "/chat/{chat_id}/metadata", dependencies=[Depends(AuthBearer())], tags=["Chat"] +) +async def update_chat_metadata_handler( + chat_data: ChatUpdatableProperties, chat_id: UUID, current_user: User = Depends(get_current_user), -): +) -> Chat: """ Update chat attributes """ + commons = common_dependencies() user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email}) - chat = get_chat_details(commons, chat_id)[0] - if user_id != chat.get('user_id'): + chat = get_chat_by_id(chat_id) + if user_id != chat.user_id: raise HTTPException(status_code=403, detail="Chat not owned by user") - return update_chat(commons=commons, chat_id=chat_id, chat_name=chat_message.chat_name) + return update_chat(chat_id=chat_id, chat_data=chat_data) + + +# helper method for update and create chat +def check_user_limit( + email, +): + date = time.strftime("%Y%m%d") + max_requests_number = os.getenv("MAX_REQUESTS_NUMBER") + commons = common_dependencies() + userItem = fetch_user_stats(commons, User(email=email), date) + old_request_count = userItem["requests_count"] + + update_user_request_count( + commons, email, date, requests_count=old_request_count + 1 + ) + if old_request_count >= float(max_requests_number): + raise HTTPException( + status_code=429, + detail="You have reached the maximum number of requests for today.", + ) # create new chat @chat_router.post("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"]) async def create_chat_handler( - request: Request, - commons: CommonsDep, - chat_message: ChatMessage, + chat_data: CreateChatProperties, current_user: User = Depends(get_current_user), ): """ Create a new chat with initial chat messages. 
""" - return chat_handler( - request, commons, None, chat_message, current_user.email, is_new_chat=True - ) + + commons = common_dependencies() + user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email}) + return create_chat(user_id=user_id, chat_data=chat_data) + + +# add new question to chat +@chat_router.post( + "/chat/{chat_id}/question", dependencies=[Depends(AuthBearer())], tags=["Chat"] +) +async def create_question_handler( + request: Request, + chat_question: ChatQuestion, + chat_id: UUID, + current_user: User = Depends(get_current_user), +) -> ChatHistory: + try: + check_user_limit(current_user.email) + user_openai_api_key = request.headers.get("Openai-Api-Key") + openai_function_compatible_models = [ + "gpt-3.5-turbo-0613", + "gpt-4-0613", + ] + if chat_question.model in openai_function_compatible_models: + # TODO: RBAC with current_user + gpt_answer_generator = OpenAiFunctionBasedAnswerGenerator( + model=chat_question.model, + chat_id=chat_id, + temperature=chat_question.temperature, + max_tokens=chat_question.max_tokens, + # TODO: use user_id in vectors table instead of email + user_email=current_user.email, + user_openai_api_key=user_openai_api_key, + ) + answer = gpt_answer_generator.get_answer(chat_question.question) + else: + brainPicking = BrainPicking( + chat_id=str(chat_id), + model=chat_question.model, + max_tokens=chat_question.max_tokens, + user_id=current_user.email, + user_openai_api_key=user_openai_api_key, + ) + answer = brainPicking.generate_answer(chat_question.question) + + chat_answer = update_chat_history( + chat_id=chat_id, + user_message=chat_question.question, + assistant_answer=answer, + ) + return chat_answer + except HTTPException as e: + raise e + + +# get chat history +@chat_router.get( + "/chat/{chat_id}/history", dependencies=[Depends(AuthBearer())], tags=["Chat"] +) +async def get_chat_history_handler( + chat_id: UUID, + current_user: User = Depends(get_current_user), +) -> list[ChatHistory]: + # TODO: RBAC with current_user + return get_chat_history(chat_id) diff --git a/backend/routes/explore_routes.py b/backend/routes/explore_routes.py index 054861fa5..a7099b68b 100644 --- a/backend/routes/explore_routes.py +++ b/backend/routes/explore_routes.py @@ -5,49 +5,73 @@ from models.users import User explore_router = APIRouter() + def get_unique_user_data(commons, user): """ Retrieve unique user data vectors. """ - response = commons['supabase'].table("vectors").select( - "name:metadata->>file_name, size:metadata->>file_size", count="exact").filter("user_id", "eq", user.email).execute() + response = ( + commons["supabase"] + .table("vectors") + .select("name:metadata->>file_name, size:metadata->>file_size", count="exact") + .filter("user_id", "eq", user.email) + .execute() + ) documents = response.data # Access the data from the response # Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary unique_data = [dict(t) for t in set(tuple(d.items()) for d in documents)] return unique_data + @explore_router.get("/explore", dependencies=[Depends(AuthBearer())], tags=["Explore"]) -async def explore_endpoint( current_user: User = Depends(get_current_user)): +async def explore_endpoint(current_user: User = Depends(get_current_user)): """ Retrieve and explore unique user data vectors. 
""" commons = common_dependencies() unique_data = get_unique_user_data(commons, current_user) - unique_data.sort(key=lambda x: int(x['size']), reverse=True) + unique_data.sort(key=lambda x: int(x["size"]), reverse=True) return {"documents": unique_data} -@explore_router.delete("/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"]) -async def delete_endpoint( file_name: str, credentials: dict = Depends(AuthBearer())): +@explore_router.delete( + "/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"] +) +async def delete_endpoint(file_name: str, credentials: dict = Depends(AuthBearer())): """ Delete a specific user file by file name. """ commons = common_dependencies() - user = User(email=credentials.get('email', 'none')) + user = User(email=credentials.get("email", "none")) # Cascade delete the summary from the database first, because it has a foreign key constraint - commons['supabase'].table("summaries").delete().match( - {"metadata->>file_name": file_name}).execute() - commons['supabase'].table("vectors").delete().match( - {"metadata->>file_name": file_name, "user_id": user.email}).execute() + commons["supabase"].table("summaries").delete().match( + {"metadata->>file_name": file_name} + ).execute() + commons["supabase"].table("vectors").delete().match( + {"metadata->>file_name": file_name, "user_id": user.email} + ).execute() return {"message": f"{file_name} of user {user.email} has been deleted."} -@explore_router.get("/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"]) -async def download_endpoint( file_name: str, current_user: User = Depends(get_current_user)): + +@explore_router.get( + "/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"] +) +async def download_endpoint( + file_name: str, current_user: User = Depends(get_current_user) +): """ Download a specific user file by file name. 
""" commons = common_dependencies() - response = commons['supabase'].table("vectors").select( - "metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url", "content").match({"metadata->>file_name": file_name, "user_id": current_user.email}).execute() + response = ( + commons["supabase"] + .table("vectors") + .select( + "metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url", + "content", + ) + .match({"metadata->>file_name": file_name, "user_id": current_user.email}) + .execute() + ) documents = response.data return {"documents": documents} diff --git a/backend/utils/chats.py b/backend/utils/chats.py index c6d44bd3c..228934393 100644 --- a/backend/utils/chats.py +++ b/backend/utils/chats.py @@ -1,48 +1,10 @@ from logger import get_logger from models.chats import ChatMessage -from models.settings import CommonsDep +from models.settings import common_dependencies logger = get_logger(__name__) -def create_chat(commons: CommonsDep, user_id, history, chat_name): - # Chat is created upon the user's first question asked - logger.info(f"New chat entry in chats table for user {user_id}") - - # Insert a new row into the chats table - new_chat = { - "user_id": user_id, - "history": history, # Empty chat to start - "chat_name": chat_name, - } - insert_response = commons["supabase"].table("chats").insert(new_chat).execute() - logger.info(f"Insert response {insert_response.data}") - - return insert_response - - -def update_chat(commons: CommonsDep, chat_id, history=None, chat_name=None): - if not chat_id: - logger.error("No chat_id provided") - return - - updates = {} - - if history is not None: - updates["history"] = history - - if chat_name is not None: - updates["chat_name"] = chat_name - - if updates: - commons["supabase"].table("chats").update(updates).match( - {"chat_id": chat_id} - ).execute() - logger.info(f"Chat {chat_id} updated") - else: - logger.info(f"No updates to apply for chat {chat_id}") - - def get_chat_name_from_first_question(chat_message: ChatMessage): # Step 1: Get the summary of the first question # first_question_summary = summarize_as_title(chat_message.question) diff --git a/backend/utils/users.py b/backend/utils/users.py index 09e1e90ba..1651e60f3 100644 --- a/backend/utils/users.py +++ b/backend/utils/users.py @@ -6,32 +6,45 @@ from models.users import User logger = get_logger(__name__) + def create_user(commons: CommonsDep, email, date): logger.info(f"New user entry in db document for user {email}") - return(commons['supabase'].table("users").insert( - {"email": email, "date": date, "requests_count": 1}).execute()) + return ( + commons["supabase"] + .table("users") + .insert({"email": email, "date": date, "requests_count": 1}) + .execute() + ) + def update_user_request_count(commons: CommonsDep, email, date, requests_count): logger.info(f"User {email} request count updated to {requests_count}") - commons['supabase'].table("users").update( - { "requests_count": requests_count}).match({"email": email, "date": date}).execute() - -def fetch_user_id_from_credentials(commons: CommonsDep,credentials): - user = User(email=credentials.get('email', 'none')) + commons["supabase"].table("users").update({"requests_count": requests_count}).match( + {"email": email, "date": date} + ).execute() + + +def fetch_user_id_from_credentials(commons: CommonsDep, credentials): + user = User(email=credentials.get("email", "none")) # Fetch the user's UUID based on their email - response = 
commons['supabase'].from_('users').select('user_id').filter("email", "eq", user.email).execute() + response = ( + commons["supabase"] + .from_("users") + .select("user_id") + .filter("email", "eq", user.email) + .execute() + ) userItem = next(iter(response.data or []), {}) - if userItem == {}: + if userItem == {}: date = time.strftime("%Y%m%d") - create_user_response = create_user(commons, email= user.email, date=date) - user_id = create_user_response.data[0]['user_id'] + create_user_response = create_user(commons, email=user.email, date=date) + user_id = create_user_response.data[0]["user_id"] - else: - user_id = userItem['user_id'] + else: + user_id = userItem["user_id"] return user_id - diff --git a/backend/utils/vectors.py b/backend/utils/vectors.py index 02e31bec2..05a5fd837 100644 --- a/backend/utils/vectors.py +++ b/backend/utils/vectors.py @@ -9,32 +9,44 @@ from pydantic import BaseModel logger = get_logger(__name__) -class Neurons(BaseModel): +class Neurons(BaseModel): commons: CommonsDep settings = BrainSettings() - + def create_vector(self, user_id, doc, user_openai_api_key=None): logger.info(f"Creating vector for document") logger.info(f"Document: {doc}") if user_openai_api_key: - self.commons['documents_vector_store']._embedding = OpenAIEmbeddings(openai_api_key=user_openai_api_key) + self.commons["documents_vector_store"]._embedding = OpenAIEmbeddings( + openai_api_key=user_openai_api_key + ) try: - sids = self.commons['documents_vector_store'].add_documents([doc]) + sids = self.commons["documents_vector_store"].add_documents([doc]) if sids and len(sids) > 0: - self.commons['supabase'].table("vectors").update({"user_id": user_id}).match({"id": sids[0]}).execute() + self.commons["supabase"].table("vectors").update( + {"user_id": user_id} + ).match({"id": sids[0]}).execute() except Exception as e: logger.error(f"Error creating vector for document {e}") def create_embedding(self, content): - return self.commons['embeddings'].embed_query(content) + return self.commons["embeddings"].embed_query(content) - def similarity_search(self, query, table='match_summaries', top_k=5, threshold=0.5): + def similarity_search(self, query, table="match_summaries", top_k=5, threshold=0.5): query_embedding = self.create_embedding(query) - summaries = self.commons['supabase'].rpc( - table, {'query_embedding': query_embedding, - 'match_count': top_k, 'match_threshold': threshold} - ).execute() + summaries = ( + self.commons["supabase"] + .rpc( + table, + { + "query_embedding": query_embedding, + "match_count": top_k, + "match_threshold": threshold, + }, + ) + .execute() + ) return summaries.data @@ -42,11 +54,10 @@ def create_summary(commons: CommonsDep, document_id, content, metadata): logger.info(f"Summarizing document {content[:100]}") summary = llm_summerize(content) logger.info(f"Summary: {summary}") - metadata['document_id'] = document_id - summary_doc_with_metadata = Document( - page_content=summary, metadata=metadata) - sids = commons['summaries_vector_store'].add_documents( - [summary_doc_with_metadata]) + metadata["document_id"] = document_id + summary_doc_with_metadata = Document(page_content=summary, metadata=metadata) + sids = commons["summaries_vector_store"].add_documents([summary_doc_with_metadata]) if sids and len(sids) > 0: - commons['supabase'].table("summaries").update( - {"document_id": document_id}).match({"id": sids[0]}).execute() + commons["supabase"].table("summaries").update( + {"document_id": document_id} + ).match({"id": sids[0]}).execute() diff --git 
a/backend/vectorstore/supabase.py b/backend/vectorstore/supabase.py index 0b2992cc9..a2146c7b5 100644 --- a/backend/vectorstore/supabase.py +++ b/backend/vectorstore/supabase.py @@ -7,18 +7,26 @@ from supabase import Client class CustomSupabaseVectorStore(SupabaseVectorStore): - '''A custom vector store that uses the match_vectors table instead of the vectors table.''' + """A custom vector store that uses the match_vectors table instead of the vectors table.""" + user_id: str - def __init__(self, client: Client, embedding: OpenAIEmbeddings, table_name: str, user_id: str = "none"): + + def __init__( + self, + client: Client, + embedding: OpenAIEmbeddings, + table_name: str, + user_id: str = "none", + ): super().__init__(client, embedding, table_name) self.user_id = user_id - + def similarity_search( - self, - query: str, - table: str = "match_vectors", - k: int = 6, - threshold: float = 0.5, + self, + query: str, + table: str = "match_vectors", + k: int = 6, + threshold: float = 0.5, **kwargs: Any ) -> List[Document]: vectors = self._embedding.embed_documents([query]) @@ -46,4 +54,4 @@ class CustomSupabaseVectorStore(SupabaseVectorStore): documents = [doc for doc, _ in match_result] - return documents \ No newline at end of file + return documents diff --git a/docs/docs/backend/api/chat.md b/docs/docs/backend/api/chat.md new file mode 100644 index 000000000..a8e428957 --- /dev/null +++ b/docs/docs/backend/api/chat.md @@ -0,0 +1,51 @@ +--- +sidebar_position: 2 +--- + +# Chat system + +**URL**: https://api.quivr.app/chat + +**Swagger**: https://api.quivr.app/docs + +## Overview + +Users can create multiple chat sessions, each with its own set of chat messages. The application provides endpoints to perform various operations such as retrieving all chats for the current user, deleting specific chats, updating chat attributes, creating new chats with initial messages, adding new questions to existing chats, and retrieving the chat history. These features enable users to communicate and interact with their data in a conversational manner. + +1. **Retrieve all chats for the current user:** + + - HTTP method: GET + - Endpoint: `/chat` + - Description: This endpoint retrieves all the chats associated with the current authenticated user. It returns a list of chat objects containing the chat ID and chat name for each chat. + +2. **Delete a specific chat by chat ID:** + + - HTTP method: DELETE + - Endpoint: `/chat/{chat_id}` + - Description: This endpoint allows deleting a specific chat identified by its chat ID. + +3. **Update chat attributes:** + + - HTTP method: PUT + - Endpoint: `/chat/{chat_id}/metadata` + - Description: This endpoint is used to update the attributes of a chat, such as the chat name. + +4. **Create a new chat with initial chat messages:** + + - HTTP method: POST + - Endpoint: `/chat` + - Description: This endpoint creates a new chat with initial chat messages. It expects the chat name in the request payload. + +5. **Add a new question to a chat:** + + - HTTP method: POST + - Endpoint: `/chat/{chat_id}/question` + - Description: This endpoint allows adding a new question to a chat. It generates an answer for the question using different models based on the provided model type. + + Models like gpt-4-0613 and gpt-3.5-turbo-0613 use a custom OpenAI function-based answer generator. + ![Function based answer generator](../../../static/img/answer_schema.png) + +6. 
**Get the chat history:** + - HTTP method: GET + - Endpoint: `/chat/{chat_id}/history` + - Description: This endpoint retrieves the chat history for a specific chat identified by its chat ID. diff --git a/docs/static/img/answer_schema.png b/docs/static/img/answer_schema.png new file mode 100644 index 000000000..7fe109292 Binary files /dev/null and b/docs/static/img/answer_schema.png differ diff --git a/frontend/app/chat/[chatId]/context/ChatContext.tsx b/frontend/app/chat/[chatId]/context/ChatContext.tsx new file mode 100644 index 000000000..4eb24a57b --- /dev/null +++ b/frontend/app/chat/[chatId]/context/ChatContext.tsx @@ -0,0 +1,48 @@ +"use client"; + +import { createContext, useContext, useState } from "react"; + +import { ChatHistory } from "../types"; + +type ChatContextProps = { + history: ChatHistory[]; + setHistory: (history: ChatHistory[]) => void; + addToHistory: (message: ChatHistory) => void; +}; + +export const ChatContext = createContext( + undefined +); + +export const ChatProvider = ({ + children, +}: { + children: JSX.Element | JSX.Element[]; +}): JSX.Element => { + const [history, setHistory] = useState([]); + const addToHistory = (message: ChatHistory) => { + setHistory((prevHistory) => [...prevHistory, message]); + }; + + return ( + + {children} + + ); +}; + +export const useChatContext = (): ChatContextProps => { + const context = useContext(ChatContext); + + if (context === undefined) { + throw new Error("useChatContext must be used inside ChatProvider"); + } + + return context; +}; diff --git a/frontend/app/chat/[chatId]/hooks/useChat.ts b/frontend/app/chat/[chatId]/hooks/useChat.ts new file mode 100644 index 000000000..e431edc1f --- /dev/null +++ b/frontend/app/chat/[chatId]/hooks/useChat.ts @@ -0,0 +1,89 @@ +import { AxiosError } from "axios"; +import { useParams } from "next/navigation"; +import { useEffect, useState } from "react"; + +import { useBrainConfig } from "@/lib/context/BrainConfigProvider/hooks/useBrainConfig"; +import { useToast } from "@/lib/hooks"; + +import { useChatService } from "./useChatService"; +import { useChatContext } from "../context/ChatContext"; +import { ChatQuestion } from "../types"; + +// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types +export const useChat = () => { + const params = useParams(); + const [chatId, setChatId] = useState( + params?.chatId as string | undefined + ); + const [generatingAnswer, setGeneratingAnswer] = useState(false); + const { + config: { maxTokens, model, temperature }, + } = useBrainConfig(); + const { history, setHistory, addToHistory } = useChatContext(); + const { publish } = useToast(); + + const { + createChat, + getChatHistory, + addQuestion: addQuestionToChat, + } = useChatService(); + + useEffect(() => { + const fetchHistory = async () => { + const chatHistory = await getChatHistory(chatId); + setHistory(chatHistory); + }; + void fetchHistory(); + }, [chatId]); + + const generateNewChatIdFromName = async ( + chatName: string + ): Promise => { + const rep = await createChat({ name: chatName }); + setChatId(rep.data.chat_id); + + return rep.data.chat_id; + }; + + const addQuestion = async (question: string, callback?: () => void) => { + const chatQuestion: ChatQuestion = { + model, + question, + temperature, + max_tokens: maxTokens, + }; + + try { + setGeneratingAnswer(true); + const currentChatId = + chatId ?? 
+ // if chatId is undefined, we need to create a new chat on fly + (await generateNewChatIdFromName( + question.split(" ").slice(0, 3).join(" ") + )); + const answer = await addQuestionToChat(currentChatId, chatQuestion); + addToHistory(answer); + callback?.(); + } catch (error) { + console.error({ error }); + + if ((error as AxiosError).response?.status === 429) { + publish({ + variant: "danger", + text: "You have reached the limit of requests, please try again later", + }); + + return; + } + + publish({ + variant: "danger", + text: "Error occurred while getting answer", + }); + } finally { + setGeneratingAnswer(false); + } + }; + + return { history, addQuestion, generatingAnswer }; +}; diff --git a/frontend/app/chat/[chatId]/hooks/useChatService.ts b/frontend/app/chat/[chatId]/hooks/useChatService.ts new file mode 100644 index 000000000..ead69caac --- /dev/null +++ b/frontend/app/chat/[chatId]/hooks/useChatService.ts @@ -0,0 +1,40 @@ +import { useAxios } from "@/lib/hooks"; + +import { ChatEntity, ChatHistory, ChatQuestion } from "../types"; + +// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types +export const useChatService = () => { + const { axiosInstance } = useAxios(); + + const createChat = async ({ name }: { name: string }) => { + return axiosInstance.post(`/chat`, { name }); + }; + + const getChatHistory = async (chatId: string | undefined) => { + if (chatId === undefined) { + return []; + } + const rep = ( + await axiosInstance.get(`/chat/${chatId}/history`) + ).data; + + return rep; + }; + const addQuestion = async ( + chatId: string, + chatQuestion: ChatQuestion + ): Promise => { + return ( + await axiosInstance.post( + `/chat/${chatId}/question`, + chatQuestion + ) + ).data; + }; + + return { + createChat, + getChatHistory, + addQuestion, + }; +}; diff --git a/frontend/app/chat/[chatId]/page.tsx b/frontend/app/chat/[chatId]/page.tsx index 3980fd944..3450d4fb1 100644 --- a/frontend/app/chat/[chatId]/page.tsx +++ b/frontend/app/chat/[chatId]/page.tsx @@ -1,31 +1,12 @@ /* eslint-disable */ "use client"; -import { UUID } from "crypto"; -import { useEffect } from "react"; import PageHeading from "@/lib/components/ui/PageHeading"; -import useChatsContext from "@/lib/context/ChatsProvider/hooks/useChatsContext"; import { ChatInput, ChatMessages } from "../components"; +import { ChatProvider } from "./context/ChatContext"; -interface ChatPageProps { - params: { - chatId: UUID; - }; -} - -export default function ChatPage({ params }: ChatPageProps) { - const chatId: UUID | undefined = params.chatId; - - const { fetchChat, resetChat } = useChatsContext(); - - useEffect(() => { - if (!chatId) { - resetChat(); - } - fetchChat(chatId); - }, []); - +export default function ChatPage() { return (
@@ -33,12 +14,14 @@ export default function ChatPage({ params }: ChatPageProps) {
         title="Chat with your brain"
         subtitle="Talk to a language model about your uploaded data"
       />
      [The JSX element tags in this hunk were stripped during extraction and the exact markup is not recoverable; based on the imports changed above, the hunk removes the old page layout and wraps the existing ChatMessages and ChatInput components in the new ChatProvider.]
); diff --git a/frontend/app/chat/[chatId]/types/index.ts b/frontend/app/chat/[chatId]/types/index.ts new file mode 100644 index 000000000..37d9f0d32 --- /dev/null +++ b/frontend/app/chat/[chatId]/types/index.ts @@ -0,0 +1,22 @@ +import { UUID } from "crypto"; + +export type ChatQuestion = { + model: string; + question?: string; + temperature: number; + max_tokens: number; +}; +export type ChatHistory = { + chat_id: string; + message_id: string; + user_message: string; + assistant: string; + message_time: string; +}; + +export type ChatEntity = { + chat_id: UUID; + user_id: string; + creation_time: string; + chat_name: string; +}; diff --git a/frontend/app/chat/components/ChatMessages/ChatInput/MicButton.tsx b/frontend/app/chat/components/ChatMessages/ChatInput/MicButton.tsx index 99e8697b0..aa808cd36 100644 --- a/frontend/app/chat/components/ChatMessages/ChatInput/MicButton.tsx +++ b/frontend/app/chat/components/ChatMessages/ChatInput/MicButton.tsx @@ -5,8 +5,14 @@ import { MdMic, MdMicOff } from "react-icons/md"; import Button from "@/lib/components/ui/Button"; import { useSpeech } from "@/lib/context/ChatsProvider/hooks/useSpeech"; -export const MicButton = (): JSX.Element => { - const { isListening, speechSupported, startListening } = useSpeech(); +type MicButtonProps = { + setMessage: (newValue: string | ((prevValue: string) => string)) => void; +}; + +export const MicButton = ({ setMessage }: MicButtonProps): JSX.Element => { + const { isListening, speechSupported, startListening } = useSpeech({ + setMessage, + }); return (