diff --git a/backend/core/llm/qa_base.py b/backend/core/llm/qa_base.py
index 54d0b24df..fc3a3db37 100644
--- a/backend/core/llm/qa_base.py
+++ b/backend/core/llm/qa_base.py
@@ -6,22 +6,24 @@ from uuid import UUID
 from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.question_answering import load_qa_chain
-from logger import get_logger
-from models.chat import ChatHistory
-from langchain.llms.base import BaseLLM
 from langchain.chat_models import ChatOpenAI
+from langchain.llms.base import BaseLLM
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    SystemMessagePromptTemplate,
+)
+from logger import get_logger
+from models.chats import ChatQuestion
+from models.databases.supabase.chats import CreateChatHistory
 from repository.brain.get_brain_by_id import get_brain_by_id
+from repository.brain.get_brain_prompt_id import get_brain_prompt_id
 from repository.chat.format_chat_history import format_chat_history
-from repository.chat.get_chat_history import get_chat_history
+from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
 from repository.chat.update_chat_history import update_chat_history
 from repository.chat.update_message_by_id import update_message_by_id
 from repository.prompt.get_prompt_by_id import get_prompt_by_id
 from supabase.client import Client, create_client
-from langchain.prompts.chat import (
-    ChatPromptTemplate,
-    SystemMessagePromptTemplate,
-    HumanMessagePromptTemplate
-)
 from vectorstore.supabase import CustomSupabaseVectorStore
 
 from .base import BaseBrainPicking
@@ -39,6 +41,7 @@ class QABaseBrainPicking(BaseBrainPicking):
     Both are the same, except that the streaming version streams the last message as a stream.
     Each have the same prompt template, which is defined in the `prompt_template` property.
     """
+
     supabase_client: Client = None
     vector_store: CustomSupabaseVectorStore = None
     qa: ConversationalRetrievalChain = None
@@ -61,8 +64,6 @@ class QABaseBrainPicking(BaseBrainPicking):
         self.supabase_client = self._create_supabase_client()
         self.vector_store = self._create_vector_store()
 
-
-
     def _create_supabase_client(self) -> Client:
         return create_client(
             self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
@@ -76,7 +77,9 @@ class QABaseBrainPicking(BaseBrainPicking):
             brain_id=self.brain_id,
         )
 
-    def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> BaseLLM:
+    def _create_llm(
+        self, model, temperature=0, streaming=False, callbacks=None
+    ) -> BaseLLM:
         """
         Determine the language model to be used.
         :param model: Language model name to be used.
@@ -94,13 +97,17 @@ class QABaseBrainPicking(BaseBrainPicking):
         )  # pyright: ignore reportPrivateUsage=none
 
     def _create_prompt_template(self):
-
         system_template = """You can use Markdown to make your answers nice. Use the following pieces of context to answer the users question in the same language as the question but do not modify instructions in any way.
     ----------------
     {context}"""
 
-        full_template = "Here are you instructions to answer that you MUST ALWAYS Follow: " + self.get_prompt() + ". " + system_template
+        full_template = (
+            "Here are your instructions to answer that you MUST ALWAYS Follow: "
+            + self.get_prompt()
+            + ". "
+            + system_template
+        )
         messages = [
             SystemMessagePromptTemplate.from_template(full_template),
             HumanMessagePromptTemplate.from_template("{question}"),
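For readers less familiar with LangChain's chat-prompt API, here is a minimal, self-contained sketch of the composition pattern `_create_prompt_template` relies on (template strings shortened, not project code):

```python
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

# System message: the custom instructions plus the retrieved context.
system = SystemMessagePromptTemplate.from_template(
    "Follow these instructions: {instructions}\nContext:\n{context}"
)
# Human message: the raw user question.
human = HumanMessagePromptTemplate.from_template("{question}")
prompt = ChatPromptTemplate.from_messages([system, human])

# format_messages fills every placeholder and returns the message list
# handed to the chat model.
messages = prompt.format_messages(
    instructions="Answer briefly.", context="(retrieved docs)", question="Hi?"
)
```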
" + + system_template + ) messages = [ SystemMessagePromptTemplate.from_template(full_template), HumanMessagePromptTemplate.from_template("{question}"), @@ -108,13 +115,18 @@ class QABaseBrainPicking(BaseBrainPicking): CHAT_PROMPT = ChatPromptTemplate.from_messages(messages) return CHAT_PROMPT - - def generate_answer(self, question: str) -> ChatHistory: + def generate_answer( + self, chat_id: UUID, question: ChatQuestion + ) -> GetChatHistoryOutput: transformed_history = format_chat_history(get_chat_history(self.chat_id)) - answering_llm = self._create_llm(model=self.model,streaming=False, callbacks=self.callbacks) + answering_llm = self._create_llm( + model=self.model, streaming=False, callbacks=self.callbacks + ) # The Chain that generates the answer to the question - doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=self._create_prompt_template()) + doc_chain = load_qa_chain( + answering_llm, chain_type="stuff", prompt=self._create_prompt_template() + ) # The Chain that combines the question and answer qa = ConversationalRetrievalChain( @@ -126,28 +138,68 @@ class QABaseBrainPicking(BaseBrainPicking): verbose=True, ) - model_response = qa({ + model_response = qa( + { "question": question, "chat_history": transformed_history, "custom_personality": self.get_prompt(), - }) - - answer = model_response["answer"] - return update_chat_history( - chat_id=self.chat_id, - user_message=question, - assistant=answer, + } ) - async def generate_stream(self, question: str) -> AsyncIterable: + answer = model_response["answer"] + + prompt_id = ( + get_brain_prompt_id(question.brain_id) if question.brain_id else None + ) + new_chat = update_chat_history( + CreateChatHistory( + **{ + "chat_id": chat_id, + "user_message": question.question, + "assistant": answer, + "brain_id": question.brain_id, + "prompt_id": prompt_id, + } + ) + ) + + brain = None + prompt = None + prompt_id = None + + if question.brain_id: + brain = get_brain_by_id(question.brain_id) + if brain and brain.prompt_id: + prompt = get_prompt_by_id(brain.prompt_id) + prompt_id = prompt.id if prompt else None + + return GetChatHistoryOutput( + **{ + "chat_id": chat_id, + "user_message": question.question, + "assistant": "", + "message_time": new_chat.message_time, + "prompt_title": prompt.title if prompt else None, + "brain_name": brain.name if brain else None, + "message_id": new_chat.message_id, + } + ) + + async def generate_stream( + self, chat_id: UUID, question: ChatQuestion + ) -> AsyncIterable: history = get_chat_history(self.chat_id) callback = AsyncIteratorCallbackHandler() self.callbacks = [callback] - answering_llm = self._create_llm(model=self.model,streaming=True, callbacks=self.callbacks) + answering_llm = self._create_llm( + model=self.model, streaming=True, callbacks=self.callbacks + ) # The Chain that generates the answer to the question - doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=self._create_prompt_template()) + doc_chain = load_qa_chain( + answering_llm, chain_type="stuff", prompt=self._create_prompt_template() + ) # The Chain that combines the question and answer qa = ConversationalRetrievalChain( @@ -184,24 +236,52 @@ class QABaseBrainPicking(BaseBrainPicking): ) ) + brain = None + prompt = None + prompt_id = None + + if question.brain_id: + brain = get_brain_by_id(question.brain_id) + if brain and brain.prompt_id: + prompt = get_prompt_by_id(brain.prompt_id) + prompt_id = prompt.id if prompt else None + streamed_chat_history = update_chat_history( - chat_id=self.chat_id, - 
diff --git a/backend/core/models/chat.py b/backend/core/models/chat.py
index 1634b55dd..a2328c7d4 100644
--- a/backend/core/models/chat.py
+++ b/backend/core/models/chat.py
@@ -1,4 +1,6 @@
 from dataclasses import asdict, dataclass
+from typing import Optional
+from uuid import UUID
 
 
 @dataclass
@@ -9,18 +11,10 @@ class Chat:
     chat_name: str
 
     def __init__(self, chat_dict: dict):
-        self.chat_id = chat_dict.get(
-            "chat_id"
-        )  # pyright: ignore reportPrivateUsage=none
-        self.user_id = chat_dict.get(
-            "user_id"
-        )  # pyright: ignore reportPrivateUsage=none
-        self.creation_time = chat_dict.get(
-            "creation_time"
-        )  # pyright: ignore reportPrivateUsage=none
-        self.chat_name = chat_dict.get(
-            "chat_name"
-        )  # pyright: ignore reportPrivateUsage=none
+        self.chat_id = chat_dict.get("chat_id", "")
+        self.user_id = chat_dict.get("user_id", "")
+        self.creation_time = chat_dict.get("creation_time", "")
+        self.chat_name = chat_dict.get("chat_name", "")
 
 
 @dataclass
@@ -30,23 +24,18 @@ class ChatHistory:
     user_message: str
     assistant: str
     message_time: str
+    prompt_id: Optional[UUID]
+    brain_id: Optional[UUID]
 
     def __init__(self, chat_dict: dict):
-        self.chat_id = chat_dict.get(
-            "chat_id"
-        )  # pyright: ignore reportPrivateUsage=none
-        self.message_id = chat_dict.get(
-            "message_id"
-        )  # pyright: ignore reportPrivateUsage=none
-        self.user_message = chat_dict.get(
-            "user_message"
-        )  # pyright: ignore reportPrivateUsage=none
-        self.assistant = chat_dict.get(
-            "assistant"
-        )  # pyright: ignore reportPrivateUsage=none
-        self.message_time = chat_dict.get(
-            "message_time"
-        )  # pyright: ignore reportPrivateUsage=none
+        self.chat_id = chat_dict.get("chat_id", "")
+        self.message_id = chat_dict.get("message_id", "")
+        self.user_message = chat_dict.get("user_message", "")
+        self.assistant = chat_dict.get("assistant", "")
+        self.message_time = chat_dict.get("message_time", "")
+
+        self.prompt_id = chat_dict.get("prompt_id")
+        self.brain_id = chat_dict.get("brain_id")
 
     def to_dict(self):
         return asdict(self)
diff --git a/backend/core/models/chats.py b/backend/core/models/chats.py
index e1cbb0f99..ecc643b9b 100644
--- a/backend/core/models/chats.py
+++ b/backend/core/models/chats.py
@@ -21,3 +21,4 @@ class ChatQuestion(BaseModel):
     question: str
     temperature: float = 0.0
     max_tokens: int = 256
+    brain_id: Optional[UUID]
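`brain_id` is declared without an explicit default, which under pydantic v1 (the version this codebase appears to target) makes it an optional field defaulting to None, so existing callers that omit it keep working. The hunk also assumes `Optional` and `UUID` are already imported in models/chats.py. A minimal sketch of that behavior:

```python
from typing import Optional
from uuid import UUID, uuid4

from pydantic import BaseModel


class Question(BaseModel):  # mirrors ChatQuestion above
    question: str
    temperature: float = 0.0
    max_tokens: int = 256
    brain_id: Optional[UUID]


assert Question(question="hi").brain_id is None          # implicit None default
assert Question(question="hi", brain_id=uuid4()).brain_id is not None
```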
diff --git a/backend/core/models/databases/supabase/chats.py b/backend/core/models/databases/supabase/chats.py
index ea7ccd719..7952eed86 100644
--- a/backend/core/models/databases/supabase/chats.py
+++ b/backend/core/models/databases/supabase/chats.py
@@ -1,4 +1,16 @@
+from typing import Optional
+from uuid import UUID
+
 from models.databases.repository import Repository
+from pydantic import BaseModel
+
+
+class CreateChatHistory(BaseModel):
+    chat_id: UUID
+    user_message: str
+    assistant: str
+    prompt_id: Optional[UUID]
+    brain_id: Optional[UUID]
 
 
 class Chats(Repository):
@@ -38,14 +50,20 @@ class Chats(Repository):
         )
         return response
 
-    def update_chat_history(self, chat_id: str, user_message: str, assistant: str):
+    def update_chat_history(self, chat_history: CreateChatHistory):
         response = (
             self.db.table("chat_history")
             .insert(
                 {
-                    "chat_id": str(chat_id),
-                    "user_message": user_message,
-                    "assistant": assistant,
+                    "chat_id": str(chat_history.chat_id),
+                    "user_message": chat_history.user_message,
+                    "assistant": chat_history.assistant,
+                    "prompt_id": str(chat_history.prompt_id)
+                    if chat_history.prompt_id
+                    else None,
+                    "brain_id": str(chat_history.brain_id)
+                    if chat_history.brain_id
+                    else None,
                 }
            )
            .execute()
diff --git a/backend/core/repository/brain/get_brain_prompt_id.py b/backend/core/repository/brain/get_brain_prompt_id.py
new file mode 100644
index 000000000..7ea844e46
--- /dev/null
+++ b/backend/core/repository/brain/get_brain_prompt_id.py
@@ -0,0 +1,10 @@
+from uuid import UUID
+
+from repository.brain.get_brain_by_id import get_brain_by_id
+
+
+def get_brain_prompt_id(brain_id: UUID) -> UUID | None:
+    brain = get_brain_by_id(brain_id)
+    prompt_id = brain.prompt_id if brain else None
+
+    return prompt_id
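With `CreateChatHistory` in place, writing a history row becomes a single validated call through the repository layer. A usage sketch (all values hypothetical; `prompt_id` and `brain_id` stay None for chats not attached to a brain):

```python
from uuid import uuid4

from models.databases.supabase.chats import CreateChatHistory
from repository.chat.update_chat_history import update_chat_history

new_row = update_chat_history(
    CreateChatHistory(
        chat_id=uuid4(),                      # hypothetical chat
        user_message="What is Quivr?",
        assistant="A RAG-based second brain.",
        prompt_id=None,                       # no custom prompt attached
        brain_id=None,                        # chat not tied to a brain
    )
)
print(new_row.message_id, new_row.message_time)
```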
diff --git a/backend/core/repository/chat/get_chat_history.py b/backend/core/repository/chat/get_chat_history.py
index 4760ef210..6a11fe488 100644
--- a/backend/core/repository/chat/get_chat_history.py
+++ b/backend/core/repository/chat/get_chat_history.py
@@ -1,16 +1,57 @@
-from typing import List
+from typing import List, Optional
+from uuid import UUID
 
 from models.chat import ChatHistory
 from models.settings import get_supabase_db  # For type hinting
+from pydantic import BaseModel
+
+from repository.brain.get_brain_by_id import get_brain_by_id
+from repository.prompt.get_prompt_by_id import get_prompt_by_id
 
 
-def get_chat_history(chat_id: str) -> List[ChatHistory]:
+class GetChatHistoryOutput(BaseModel):
+    chat_id: UUID
+    message_id: UUID
+    user_message: str
+    assistant: str
+    message_time: str
+    prompt_title: Optional[str]
+    brain_name: Optional[str]
+
+    def dict(self, *args, **kwargs):
+        chat_history = super().dict(*args, **kwargs)
+        chat_history["chat_id"] = str(chat_history.get("chat_id"))
+        chat_history["message_id"] = str(chat_history.get("message_id"))
+
+        return chat_history
+
+
+def get_chat_history(chat_id: str) -> List[GetChatHistoryOutput]:
     supabase_db = get_supabase_db()
-    history: List[ChatHistory] = supabase_db.get_chat_history(chat_id).data
+    history: List[dict] = supabase_db.get_chat_history(chat_id).data
     if history is None:
         return []
     else:
-        return [
-            ChatHistory(message)  # pyright: ignore reportPrivateUsage=none
-            for message in history
-        ]
+        enriched_history: List[GetChatHistoryOutput] = []
+        for message in history:
+            message = ChatHistory(message)
+            brain = None
+            if message.brain_id:
+                brain = get_brain_by_id(message.brain_id)
+
+            prompt = None
+            if message.prompt_id:
+                prompt = get_prompt_by_id(message.prompt_id)
+
+            enriched_history.append(
+                GetChatHistoryOutput(
+                    chat_id=(UUID(message.chat_id)),
+                    message_id=(UUID(message.message_id)),
+                    user_message=message.user_message,
+                    assistant=message.assistant,
+                    message_time=message.message_time,
+                    brain_name=brain.name if brain else None,
+                    prompt_title=prompt.title if prompt else None,
+                )
+            )
+        return enriched_history
diff --git a/backend/core/repository/chat/update_chat_history.py b/backend/core/repository/chat/update_chat_history.py
index 8c49bc2e5..3f948cf7a 100644
--- a/backend/core/repository/chat/update_chat_history.py
+++ b/backend/core/repository/chat/update_chat_history.py
@@ -1,15 +1,14 @@
-from typing import List  # For type hinting
+from typing import List
 
 from fastapi import HTTPException
 from models.chat import ChatHistory
+from models.databases.supabase.chats import CreateChatHistory
 from models.settings import get_supabase_db
 
 
-def update_chat_history(chat_id: str, user_message: str, assistant: str) -> ChatHistory:
+def update_chat_history(chat_history: CreateChatHistory) -> ChatHistory:
     supabase_db = get_supabase_db()
-    response: List[ChatHistory] = (
-        supabase_db.update_chat_history(chat_id, user_message, assistant)
-    ).data
+    response: List[ChatHistory] = (supabase_db.update_chat_history(chat_history)).data
     if len(response) == 0:
         raise HTTPException(
             status_code=500, detail="An exception occurred while updating chat history."
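The `dict()` override on `GetChatHistoryOutput` exists so that the `json.dumps` call in `generate_stream` does not choke on raw UUID values. A quick sketch of the behavior (values hypothetical):

```python
import json
from uuid import uuid4

from repository.chat.get_chat_history import GetChatHistoryOutput

row = GetChatHistoryOutput(
    chat_id=uuid4(),
    message_id=uuid4(),
    user_message="hi",
    assistant="hello",
    message_time="2023-08-09T15:43:00",
    prompt_title=None,
    brain_name=None,
)

# The override stringifies both UUID fields, so the row is JSON-safe.
assert isinstance(row.dict()["chat_id"], str)
payload = json.dumps(row.dict())
```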
""" @@ -163,11 +163,10 @@ async def create_question_handler( current_user.user_openai_api_key = request.headers.get("Openai-Api-Key") brain = Brain(id=brain_id) - if not current_user.user_openai_api_key: - if brain_id: - brain_details = get_brain_details(brain_id) - if brain_details: - current_user.user_openai_api_key = brain_details.openai_api_key + if not current_user.user_openai_api_key and brain_id: + brain_details = get_brain_details(brain_id) + if brain_details: + current_user.user_openai_api_key = brain_details.openai_api_key if not current_user.user_openai_api_key: user_identity = get_user_identity(current_user.id) @@ -202,9 +201,7 @@ async def create_question_handler( user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none ) - chat_answer = gpt_answer_generator.generate_answer( # pyright: ignore reportPrivateUsage=none - chat_question.question - ) + chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question) return chat_answer except HTTPException as e: @@ -276,9 +273,7 @@ async def create_stream_question_handler( print("streaming") return StreamingResponse( - gpt_answer_generator.generate_stream( # pyright: ignore reportPrivateUsage=none - chat_question.question - ), + gpt_answer_generator.generate_stream(chat_id, chat_question), media_type="text/event-stream", ) @@ -292,6 +287,6 @@ async def create_stream_question_handler( ) async def get_chat_history_handler( chat_id: UUID, -) -> List[ChatHistory]: +) -> List[GetChatHistoryOutput]: # TODO: RBAC with current_user - return get_chat_history(chat_id) # pyright: ignore reportPrivateUsage=none + return get_chat_history(str(chat_id)) diff --git a/scripts/20230809154300_add_prompt_id_brain_id_to_chat_history_table.sql b/scripts/20230809154300_add_prompt_id_brain_id_to_chat_history_table.sql new file mode 100644 index 000000000..04f3044e6 --- /dev/null +++ b/scripts/20230809154300_add_prompt_id_brain_id_to_chat_history_table.sql @@ -0,0 +1,28 @@ +BEGIN; + +-- Check if brain_id column exists in chat_history table +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'chat_history' AND column_name = 'brain_id') THEN + -- Add brain_id column + ALTER TABLE chat_history ADD COLUMN brain_id UUID REFERENCES brains(brain_id); + END IF; +END $$; + +-- Check if prompt_id column exists in chat_history table +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'chat_history' AND column_name = 'prompt_id') THEN + -- Add prompt_id column + ALTER TABLE chat_history ADD COLUMN prompt_id UUID REFERENCES prompts(id); + END IF; +END $$; + +-- Update migrations table +INSERT INTO migrations (name) +SELECT '20230809154300_add_prompt_id_brain_id_to_chat_history_table' +WHERE NOT EXISTS ( + SELECT 1 FROM migrations WHERE name = '20230809154300_add_prompt_id_brain_id_to_chat_history_table' +); + +COMMIT; diff --git a/scripts/tables.sql b/scripts/tables.sql index f018f710c..7f449826a 100644 --- a/scripts/tables.sql +++ b/scripts/tables.sql @@ -16,15 +16,6 @@ CREATE TABLE IF NOT EXISTS chats( chat_name TEXT ); --- Create chat_history table -CREATE TABLE IF NOT EXISTS chat_history ( - message_id UUID DEFAULT uuid_generate_v4(), - chat_id UUID REFERENCES chats(chat_id), - user_message TEXT, - assistant TEXT, - message_time TIMESTAMP DEFAULT current_timestamp, - PRIMARY KEY (chat_id, message_id) -); -- Create vector extension CREATE EXTENSION IF NOT EXISTS vector; @@ -148,6 +139,18 @@ CREATE TABLE IF NOT EXISTS brains ( ); +-- Create 
diff --git a/scripts/20230809154300_add_prompt_id_brain_id_to_chat_history_table.sql b/scripts/20230809154300_add_prompt_id_brain_id_to_chat_history_table.sql
new file mode 100644
index 000000000..04f3044e6
--- /dev/null
+++ b/scripts/20230809154300_add_prompt_id_brain_id_to_chat_history_table.sql
@@ -0,0 +1,28 @@
+BEGIN;
+
+-- Check if brain_id column exists in chat_history table
+DO $$
+BEGIN
+    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'chat_history' AND column_name = 'brain_id') THEN
+        -- Add brain_id column
+        ALTER TABLE chat_history ADD COLUMN brain_id UUID REFERENCES brains(brain_id);
+    END IF;
+END $$;
+
+-- Check if prompt_id column exists in chat_history table
+DO $$
+BEGIN
+    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'chat_history' AND column_name = 'prompt_id') THEN
+        -- Add prompt_id column
+        ALTER TABLE chat_history ADD COLUMN prompt_id UUID REFERENCES prompts(id);
+    END IF;
+END $$;
+
+-- Update migrations table
+INSERT INTO migrations (name)
+SELECT '20230809154300_add_prompt_id_brain_id_to_chat_history_table'
+WHERE NOT EXISTS (
+    SELECT 1 FROM migrations WHERE name = '20230809154300_add_prompt_id_brain_id_to_chat_history_table'
+);
+
+COMMIT;
diff --git a/scripts/tables.sql b/scripts/tables.sql
index f018f710c..7f449826a 100644
--- a/scripts/tables.sql
+++ b/scripts/tables.sql
@@ -16,15 +16,6 @@ CREATE TABLE IF NOT EXISTS chats(
     chat_name TEXT
 );
 
--- Create chat_history table
-CREATE TABLE IF NOT EXISTS chat_history (
-    message_id UUID DEFAULT uuid_generate_v4(),
-    chat_id UUID REFERENCES chats(chat_id),
-    user_message TEXT,
-    assistant TEXT,
-    message_time TIMESTAMP DEFAULT current_timestamp,
-    PRIMARY KEY (chat_id, message_id)
-);
 
 -- Create vector extension
 CREATE EXTENSION IF NOT EXISTS vector;
@@ -148,6 +139,18 @@ CREATE TABLE IF NOT EXISTS brains (
 );
 
 
+-- Create chat_history table (moved below brains/prompts so its foreign keys resolve)
+CREATE TABLE IF NOT EXISTS chat_history (
+    message_id UUID DEFAULT uuid_generate_v4(),
+    chat_id UUID REFERENCES chats(chat_id),
+    user_message TEXT,
+    assistant TEXT,
+    message_time TIMESTAMP DEFAULT current_timestamp,
+    prompt_id UUID REFERENCES prompts(id),
+    brain_id UUID REFERENCES brains(brain_id),
+    PRIMARY KEY (chat_id, message_id)
+);
+
 -- Create brains X users table
 CREATE TABLE IF NOT EXISTS brains_users (
     brain_id UUID,
@@ -212,7 +215,7 @@ CREATE TABLE IF NOT EXISTS migrations (
 );
 
 INSERT INTO migrations (name)
-SELECT '20230802120700_add_prompt_id_to_brain'
+SELECT '20230809154300_add_prompt_id_brain_id_to_chat_history_table'
 WHERE NOT EXISTS (
-    SELECT 1 FROM migrations WHERE name = '20230802120700_add_prompt_id_to_brain'
+    SELECT 1 FROM migrations WHERE name = '20230809154300_add_prompt_id_brain_id_to_chat_history_table'
 );
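Both migration guards follow the same idempotent pattern: probe information_schema first, and only run the ALTER when the column is absent, so re-running the script is harmless. The same guard expressed in Python (psycopg2 is used purely for illustration; the project applies these .sql scripts directly, and the DSN is hypothetical):

```python
import psycopg2  # illustration only, not a project dependency


def add_column_if_missing(conn, table: str, column: str, ddl: str) -> None:
    """Python rendering of the migration's information_schema guard."""
    with conn.cursor() as cur:
        cur.execute(
            "SELECT 1 FROM information_schema.columns "
            "WHERE table_name = %s AND column_name = %s",
            (table, column),
        )
        if cur.fetchone() is None:
            cur.execute(ddl)


with psycopg2.connect("dbname=postgres") as conn:  # hypothetical DSN
    add_column_if_missing(
        conn,
        "chat_history",
        "brain_id",
        "ALTER TABLE chat_history ADD COLUMN brain_id UUID "
        "REFERENCES brains(brain_id)",
    )
```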