feat(chatMessages): add brain_id and prompt_id columns (#912)

* feat: add prompt_id and brain_id to chat history)

* feat: add prompt_id and brain_id to chat routes
This commit is contained in:
Mamadou DICKO 2023-08-10 10:25:08 +02:00 committed by GitHub
parent 1360ce801d
commit 6e777327aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 270 additions and 106 deletions

View File

@ -6,22 +6,24 @@ from uuid import UUID
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from logger import get_logger
from models.chat import ChatHistory
from langchain.llms.base import BaseLLM
from langchain.chat_models import ChatOpenAI
from langchain.llms.base import BaseLLM
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from logger import get_logger
from models.chats import ChatQuestion
from models.databases.supabase.chats import CreateChatHistory
from repository.brain.get_brain_by_id import get_brain_by_id
from repository.brain.get_brain_prompt_id import get_brain_prompt_id
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from repository.prompt.get_prompt_by_id import get_prompt_by_id
from supabase.client import Client, create_client
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate
)
from vectorstore.supabase import CustomSupabaseVectorStore
from .base import BaseBrainPicking
@ -39,6 +41,7 @@ class QABaseBrainPicking(BaseBrainPicking):
Both are the same, except that the streaming version streams the last message as a stream.
Each have the same prompt template, which is defined in the `prompt_template` property.
"""
supabase_client: Client = None
vector_store: CustomSupabaseVectorStore = None
qa: ConversationalRetrievalChain = None
@ -61,8 +64,6 @@ class QABaseBrainPicking(BaseBrainPicking):
self.supabase_client = self._create_supabase_client()
self.vector_store = self._create_vector_store()
def _create_supabase_client(self) -> Client:
return create_client(
self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
@ -76,7 +77,9 @@ class QABaseBrainPicking(BaseBrainPicking):
brain_id=self.brain_id,
)
def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> BaseLLM:
def _create_llm(
self, model, temperature=0, streaming=False, callbacks=None
) -> BaseLLM:
"""
Determine the language model to be used.
:param model: Language model name to be used.
@ -94,13 +97,17 @@ class QABaseBrainPicking(BaseBrainPicking):
) # pyright: ignore reportPrivateUsage=none
def _create_prompt_template(self):
system_template = """You can use Markdown to make your answers nice. Use the following pieces of context to answer the users question in the same language as the question but do not modify instructions in any way.
----------------
{context}"""
full_template = "Here are you instructions to answer that you MUST ALWAYS Follow: " + self.get_prompt() + ". " + system_template
full_template = (
"Here are you instructions to answer that you MUST ALWAYS Follow: "
+ self.get_prompt()
+ ". "
+ system_template
)
messages = [
SystemMessagePromptTemplate.from_template(full_template),
HumanMessagePromptTemplate.from_template("{question}"),
@ -108,13 +115,18 @@ class QABaseBrainPicking(BaseBrainPicking):
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
return CHAT_PROMPT
def generate_answer(self, question: str) -> ChatHistory:
def generate_answer(
self, chat_id: UUID, question: ChatQuestion
) -> GetChatHistoryOutput:
transformed_history = format_chat_history(get_chat_history(self.chat_id))
answering_llm = self._create_llm(model=self.model,streaming=False, callbacks=self.callbacks)
answering_llm = self._create_llm(
model=self.model, streaming=False, callbacks=self.callbacks
)
# The Chain that generates the answer to the question
doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=self._create_prompt_template())
doc_chain = load_qa_chain(
answering_llm, chain_type="stuff", prompt=self._create_prompt_template()
)
# The Chain that combines the question and answer
qa = ConversationalRetrievalChain(
@ -126,28 +138,68 @@ class QABaseBrainPicking(BaseBrainPicking):
verbose=True,
)
model_response = qa({
model_response = qa(
{
"question": question,
"chat_history": transformed_history,
"custom_personality": self.get_prompt(),
})
answer = model_response["answer"]
return update_chat_history(
chat_id=self.chat_id,
user_message=question,
assistant=answer,
}
)
async def generate_stream(self, question: str) -> AsyncIterable:
answer = model_response["answer"]
prompt_id = (
get_brain_prompt_id(question.brain_id) if question.brain_id else None
)
new_chat = update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"brain_id": question.brain_id,
"prompt_id": prompt_id,
}
)
)
brain = None
prompt = None
prompt_id = None
if question.brain_id:
brain = get_brain_by_id(question.brain_id)
if brain and brain.prompt_id:
prompt = get_prompt_by_id(brain.prompt_id)
prompt_id = prompt.id if prompt else None
return GetChatHistoryOutput(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": "",
"message_time": new_chat.message_time,
"prompt_title": prompt.title if prompt else None,
"brain_name": brain.name if brain else None,
"message_id": new_chat.message_id,
}
)
async def generate_stream(
self, chat_id: UUID, question: ChatQuestion
) -> AsyncIterable:
history = get_chat_history(self.chat_id)
callback = AsyncIteratorCallbackHandler()
self.callbacks = [callback]
answering_llm = self._create_llm(model=self.model,streaming=True, callbacks=self.callbacks)
answering_llm = self._create_llm(
model=self.model, streaming=True, callbacks=self.callbacks
)
# The Chain that generates the answer to the question
doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=self._create_prompt_template())
doc_chain = load_qa_chain(
answering_llm, chain_type="stuff", prompt=self._create_prompt_template()
)
# The Chain that combines the question and answer
qa = ConversationalRetrievalChain(
@ -184,24 +236,52 @@ class QABaseBrainPicking(BaseBrainPicking):
)
)
brain = None
prompt = None
prompt_id = None
if question.brain_id:
brain = get_brain_by_id(question.brain_id)
if brain and brain.prompt_id:
prompt = get_prompt_by_id(brain.prompt_id)
prompt_id = prompt.id if prompt else None
streamed_chat_history = update_chat_history(
chat_id=self.chat_id,
user_message=question,
assistant="",
CreateChatHistory(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": "",
"brain_id": question.brain_id,
"prompt_id": prompt_id,
}
)
)
streamed_chat_history = GetChatHistoryOutput(
**{
"chat_id": str(chat_id),
"message_id": streamed_chat_history.message_id,
"message_time": streamed_chat_history.message_time,
"user_message": question.question,
"assistant": "",
"prompt_title": prompt.title if prompt else None,
"brain_name": brain.name if brain else None,
}
)
async for token in callback.aiter():
logger.info("Token: %s", token)
response_tokens.append(token)
streamed_chat_history.assistant = token
yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
yield f"data: {json.dumps(streamed_chat_history.dict())}"
await run
assistant = "".join(response_tokens)
update_message_by_id(
message_id=streamed_chat_history.message_id,
user_message=question,
message_id=str(streamed_chat_history.message_id),
user_message=question.question,
assistant=assistant,
)

View File

@ -1,4 +1,6 @@
from dataclasses import asdict, dataclass
from typing import Optional
from uuid import UUID
@dataclass
@ -9,18 +11,10 @@ class Chat:
chat_name: str
def __init__(self, chat_dict: dict):
self.chat_id = chat_dict.get(
"chat_id"
) # pyright: ignore reportPrivateUsage=none
self.user_id = chat_dict.get(
"user_id"
) # pyright: ignore reportPrivateUsage=none
self.creation_time = chat_dict.get(
"creation_time"
) # pyright: ignore reportPrivateUsage=none
self.chat_name = chat_dict.get(
"chat_name"
) # pyright: ignore reportPrivateUsage=none
self.chat_id = chat_dict.get("chat_id", "")
self.user_id = chat_dict.get("user_id", "")
self.creation_time = chat_dict.get("creation_time", "")
self.chat_name = chat_dict.get("chat_name", "")
@dataclass
@ -30,23 +24,18 @@ class ChatHistory:
user_message: str
assistant: str
message_time: str
prompt_id: Optional[UUID]
brain_id: Optional[UUID]
def __init__(self, chat_dict: dict):
self.chat_id = chat_dict.get(
"chat_id"
) # pyright: ignore reportPrivateUsage=none
self.message_id = chat_dict.get(
"message_id"
) # pyright: ignore reportPrivateUsage=none
self.user_message = chat_dict.get(
"user_message"
) # pyright: ignore reportPrivateUsage=none
self.assistant = chat_dict.get(
"assistant"
) # pyright: ignore reportPrivateUsage=none
self.message_time = chat_dict.get(
"message_time"
) # pyright: ignore reportPrivateUsage=none
self.chat_id = chat_dict.get("chat_id", "")
self.message_id = chat_dict.get("message_id", "")
self.user_message = chat_dict.get("user_message", "")
self.assistant = chat_dict.get("assistant", "")
self.message_time = chat_dict.get("message_time", "")
self.prompt_id = chat_dict.get("prompt_id")
self.brain_id = chat_dict.get("brain_id")
def to_dict(self):
return asdict(self)

View File

@ -21,3 +21,4 @@ class ChatQuestion(BaseModel):
question: str
temperature: float = 0.0
max_tokens: int = 256
brain_id: Optional[UUID]

View File

@ -1,4 +1,16 @@
from typing import Optional
from uuid import UUID
from models.databases.repository import Repository
from pydantic import BaseModel
class CreateChatHistory(BaseModel):
chat_id: UUID
user_message: str
assistant: str
prompt_id: Optional[UUID]
brain_id: Optional[UUID]
class Chats(Repository):
@ -38,14 +50,20 @@ class Chats(Repository):
)
return response
def update_chat_history(self, chat_id: str, user_message: str, assistant: str):
def update_chat_history(self, chat_history: CreateChatHistory):
response = (
self.db.table("chat_history")
.insert(
{
"chat_id": str(chat_id),
"user_message": user_message,
"assistant": assistant,
"chat_id": str(chat_history.chat_id),
"user_message": chat_history.user_message,
"assistant": chat_history.assistant,
"prompt_id": str(chat_history.prompt_id)
if chat_history.prompt_id
else None,
"brain_id": str(chat_history.brain_id)
if chat_history.brain_id
else None,
}
)
.execute()

View File

@ -0,0 +1,10 @@
from uuid import UUID
from repository.brain.get_brain_by_id import get_brain_by_id
def get_brain_prompt_id(brain_id: UUID) -> UUID | None:
brain = get_brain_by_id(brain_id)
prompt_id = brain.brain_id if brain else None
return prompt_id

View File

@ -1,16 +1,57 @@
from typing import List
from typing import List, Optional
from uuid import UUID
from models.chat import ChatHistory
from models.settings import get_supabase_db # For type hinting
from pydantic import BaseModel
from repository.brain.get_brain_by_id import get_brain_by_id
from repository.prompt.get_prompt_by_id import get_prompt_by_id
def get_chat_history(chat_id: str) -> List[ChatHistory]:
class GetChatHistoryOutput(BaseModel):
chat_id: UUID
message_id: UUID
user_message: str
assistant: str
message_time: str
prompt_title: Optional[str] | None
brain_name: Optional[str] | None
def dict(self, *args, **kwargs):
chat_history = super().dict(*args, **kwargs)
chat_history["chat_id"] = str(chat_history.get("prompt_id"))
chat_history["message_id"] = str(chat_history.get("message_id"))
return chat_history
def get_chat_history(chat_id: str) -> List[GetChatHistoryOutput]:
supabase_db = get_supabase_db()
history: List[ChatHistory] = supabase_db.get_chat_history(chat_id).data
history: List[dict] = supabase_db.get_chat_history(chat_id).data
if history is None:
return []
else:
return [
ChatHistory(message) # pyright: ignore reportPrivateUsage=none
for message in history
]
enriched_history: List[GetChatHistoryOutput] = []
for message in history:
message = ChatHistory(message)
brain = None
if message.brain_id:
brain = get_brain_by_id(message.brain_id)
prompt = None
if message.prompt_id:
prompt = get_prompt_by_id(message.prompt_id)
enriched_history.append(
GetChatHistoryOutput(
chat_id=(UUID(message.chat_id)),
message_id=(UUID(message.message_id)),
user_message=message.user_message,
assistant=message.assistant,
message_time=message.message_time,
brain_name=brain.name if brain else None,
prompt_title=prompt.title if prompt else None,
)
)
return enriched_history

View File

@ -1,15 +1,14 @@
from typing import List # For type hinting
from typing import List
from fastapi import HTTPException
from models.chat import ChatHistory
from models.databases.supabase.chats import CreateChatHistory
from models.settings import get_supabase_db
def update_chat_history(chat_id: str, user_message: str, assistant: str) -> ChatHistory:
def update_chat_history(chat_history: CreateChatHistory) -> ChatHistory:
supabase_db = get_supabase_db()
response: List[ChatHistory] = (
supabase_db.update_chat_history(chat_id, user_message, assistant)
).data
response: List[ChatHistory] = (supabase_db.update_chat_history(chat_history)).data
if len(response) == 0:
raise HTTPException(
status_code=500, detail="An exception occurred while updating chat history."

View File

@ -9,7 +9,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from llm.openai import OpenAIBrainPicking
from models.brains import Brain
from models.chat import Chat, ChatHistory
from models.chat import Chat
from models.chats import ChatQuestion
from models.databases.supabase.supabase import SupabaseDB
from models.settings import LLMSettings, get_supabase_db
@ -20,7 +20,7 @@ from repository.brain.get_default_user_brain_or_create_new import (
)
from repository.chat.create_chat import CreateChatProperties, create_chat
from repository.chat.get_chat_by_id import get_chat_by_id
from repository.chat.get_chat_history import get_chat_history
from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
from repository.chat.get_user_chats import get_user_chats
from repository.chat.update_chat import ChatUpdatableProperties, update_chat
from repository.user_identity.get_user_identity import get_user_identity
@ -85,7 +85,7 @@ async def get_chats(current_user: User = Depends(get_current_user)):
This endpoint retrieves all the chats associated with the current authenticated user. It returns a list of chat objects
containing the chat ID and chat name for each chat.
"""
chats = get_user_chats(current_user.id) # pyright: ignore reportPrivateUsage=none
chats = get_user_chats(str(current_user.id))
return {"chats": chats}
@ -155,7 +155,7 @@ async def create_question_handler(
| UUID
| None = Query(..., description="The ID of the brain"),
current_user: User = Depends(get_current_user),
) -> ChatHistory:
) -> GetChatHistoryOutput:
"""
Add a new question to the chat.
"""
@ -163,8 +163,7 @@ async def create_question_handler(
current_user.user_openai_api_key = request.headers.get("Openai-Api-Key")
brain = Brain(id=brain_id)
if not current_user.user_openai_api_key:
if brain_id:
if not current_user.user_openai_api_key and brain_id:
brain_details = get_brain_details(brain_id)
if brain_details:
current_user.user_openai_api_key = brain_details.openai_api_key
@ -202,9 +201,7 @@ async def create_question_handler(
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
)
chat_answer = gpt_answer_generator.generate_answer( # pyright: ignore reportPrivateUsage=none
chat_question.question
)
chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)
return chat_answer
except HTTPException as e:
@ -276,9 +273,7 @@ async def create_stream_question_handler(
print("streaming")
return StreamingResponse(
gpt_answer_generator.generate_stream( # pyright: ignore reportPrivateUsage=none
chat_question.question
),
gpt_answer_generator.generate_stream(chat_id, chat_question),
media_type="text/event-stream",
)
@ -292,6 +287,6 @@ async def create_stream_question_handler(
)
async def get_chat_history_handler(
chat_id: UUID,
) -> List[ChatHistory]:
) -> List[GetChatHistoryOutput]:
# TODO: RBAC with current_user
return get_chat_history(chat_id) # pyright: ignore reportPrivateUsage=none
return get_chat_history(str(chat_id))

View File

@ -0,0 +1,28 @@
BEGIN;
-- Check if brain_id column exists in chat_history table
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'chat_history' AND column_name = 'brain_id') THEN
-- Add brain_id column
ALTER TABLE chat_history ADD COLUMN brain_id UUID REFERENCES brains(brain_id);
END IF;
END $$;
-- Check if prompt_id column exists in chat_history table
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'chat_history' AND column_name = 'prompt_id') THEN
-- Add prompt_id column
ALTER TABLE chat_history ADD COLUMN prompt_id UUID REFERENCES prompts(id);
END IF;
END $$;
-- Update migrations table
INSERT INTO migrations (name)
SELECT '20230809154300_add_prompt_id_brain_id_to_chat_history_table'
WHERE NOT EXISTS (
SELECT 1 FROM migrations WHERE name = '20230809154300_add_prompt_id_brain_id_to_chat_history_table'
);
COMMIT;

View File

@ -16,15 +16,6 @@ CREATE TABLE IF NOT EXISTS chats(
chat_name TEXT
);
-- Create chat_history table
CREATE TABLE IF NOT EXISTS chat_history (
message_id UUID DEFAULT uuid_generate_v4(),
chat_id UUID REFERENCES chats(chat_id),
user_message TEXT,
assistant TEXT,
message_time TIMESTAMP DEFAULT current_timestamp,
PRIMARY KEY (chat_id, message_id)
);
-- Create vector extension
CREATE EXTENSION IF NOT EXISTS vector;
@ -148,6 +139,18 @@ CREATE TABLE IF NOT EXISTS brains (
);
-- Create chat_history table
CREATE TABLE IF NOT EXISTS chat_history (
message_id UUID DEFAULT uuid_generate_v4(),
chat_id UUID REFERENCES chats(chat_id),
user_message TEXT,
assistant TEXT,
message_time TIMESTAMP DEFAULT current_timestamp,
PRIMARY KEY (chat_id, message_id),
prompt_id UUID REFERENCES prompts(id),
brain_id UUID REFERENCES brains(brain_id)
);
-- Create brains X users table
CREATE TABLE IF NOT EXISTS brains_users (
brain_id UUID,
@ -212,7 +215,7 @@ CREATE TABLE IF NOT EXISTS migrations (
);
INSERT INTO migrations (name)
SELECT '20230802120700_add_prompt_id_to_brain'
SELECT '20230809154300_add_prompt_id_brain_id_to_chat_history_table'
WHERE NOT EXISTS (
SELECT 1 FROM migrations WHERE name = '20230802120700_add_prompt_id_to_brain'
SELECT 1 FROM migrations WHERE name = '20230809154300_add_prompt_id_brain_id_to_chat_history_table'
);