Mirror of https://github.com/QuivrHQ/quivr.git (synced 2024-12-15 09:32:22 +03:00)

feat: add brain prompt overwriting from chat (#1012)

Parent: 4b1f4b1412
Commit: b967c2d2d6
@@ -1,7 +1,11 @@
from typing import Optional
from uuid import UUID

from langchain.embeddings.openai import OpenAIEmbeddings
from llm.qa_base import QABaseBrainPicking
from logger import get_logger

from llm.qa_base import QABaseBrainPicking

logger = get_logger(__name__)

@@ -22,6 +26,7 @@ class OpenAIBrainPicking(QABaseBrainPicking):
chat_id: str,
max_tokens: int,
user_openai_api_key: str,
prompt_id: Optional[UUID],
streaming: bool = False,
) -> "OpenAIBrainPicking": # pyright: ignore reportPrivateUsage=none
"""
@@ -36,6 +41,7 @@ class OpenAIBrainPicking(QABaseBrainPicking):
temperature=temperature,
user_openai_api_key=user_openai_api_key,
streaming=streaming,
prompt_id=prompt_id,
)

@property
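
The hunks above thread a new prompt_id parameter through OpenAIBrainPicking and pass it on to the base class. A hedged construction sketch follows; the import path and the keyword arguments not visible in these hunks (model, temperature, brain_id) are assumptions based on the route handler further down, and all values are placeholders:

from uuid import uuid4

from llm.openai import OpenAIBrainPicking  # import path assumed

brain_picking = OpenAIBrainPicking(
    chat_id=str(uuid4()),            # placeholder ids
    brain_id=str(uuid4()),
    model="gpt-3.5-turbo",           # assumed parameter, not shown in the hunk
    temperature=0.0,                 # assumed parameter, mirrors the route handler
    max_tokens=256,
    user_openai_api_key="sk-...",    # placeholder, never hard-code a real key
    prompt_id=uuid4(),               # per-chat prompt override introduced by this commit
    streaming=False,
)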
@@ -2,9 +2,7 @@ import asyncio
import json
from typing import AsyncIterable, Awaitable, Optional
from uuid import UUID
from logger import get_logger

from supabase.client import Client, create_client
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
@@ -15,26 +13,28 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)

from logger import get_logger
from models import ChatQuestion
from models.chats import ChatQuestion
from models.databases.supabase.chats import CreateChatHistory
from repository.brain import get_brain_by_id, get_brain_prompt_id
from repository.brain import get_brain_by_id
from repository.chat import (
GetChatHistoryOutput,
format_chat_history,
get_chat_history,
update_chat_history,
format_chat_history,
GetChatHistoryOutput,
update_message_by_id,
)
from repository.prompt import get_prompt_by_id
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore

from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id

from .base import BaseBrainPicking
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT

logger = get_logger(__name__)
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."

class QABaseBrainPicking(BaseBrainPicking):
@@ -50,6 +50,7 @@ class QABaseBrainPicking(BaseBrainPicking):
supabase_client: Optional[Client] = None
vector_store: Optional[CustomSupabaseVectorStore] = None
qa: Optional[ConversationalRetrievalChain] = None
prompt_id: Optional[UUID]

def __init__(
self,
@@ -57,6 +58,7 @@ class QABaseBrainPicking(BaseBrainPicking):
brain_id: str,
chat_id: str,
streaming: bool = False,
prompt_id: Optional[UUID] = None,
**kwargs,
):
super().__init__(
@@ -68,6 +70,15 @@ class QABaseBrainPicking(BaseBrainPicking):
)
self.supabase_client = self._create_supabase_client()
self.vector_store = self._create_vector_store()
self.prompt_id = prompt_id

@property
def prompt_to_use(self):
return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)

@property
def prompt_to_use_id(self) -> Optional[UUID]:
return get_prompt_to_use_id(UUID(self.brain_id), self.prompt_id)

def _create_supabase_client(self) -> Client:
return create_client(
@@ -107,9 +118,13 @@ class QABaseBrainPicking(BaseBrainPicking):

{context}"""

prompt_content = (
self.prompt_to_use.content if self.prompt_to_use else QUIVR_DEFAULT_PROMPT
)

full_template = (
"Here are you instructions to answer that you MUST ALWAYS Follow: "
+ self.get_prompt()
"Here are your instructions to answer that you MUST ALWAYS Follow: "
+ prompt_content
+ ". "
+ system_template
)
@@ -143,19 +158,20 @@ class QABaseBrainPicking(BaseBrainPicking):
verbose=False,
)

prompt_content = (
self.prompt_to_use.content if self.prompt_to_use else QUIVR_DEFAULT_PROMPT
)

model_response = qa(
{
"question": question.question,
"chat_history": transformed_history,
"custom_personality": self.get_prompt(),
"custom_personality": prompt_content,
}
)

answer = model_response["answer"]

prompt_id = (
get_brain_prompt_id(question.brain_id) if question.brain_id else None
)
new_chat = update_chat_history(
CreateChatHistory(
**{
@@ -163,20 +179,15 @@ class QABaseBrainPicking(BaseBrainPicking):
"user_message": question.question,
"assistant": answer,
"brain_id": question.brain_id,
"prompt_id": prompt_id,
"prompt_id": self.prompt_to_use_id,
}
)
)

brain = None
prompt = None
prompt_id = None

if question.brain_id:
brain = get_brain_by_id(question.brain_id)
if brain and brain.prompt_id:
prompt = get_prompt_by_id(brain.prompt_id)
prompt_id = prompt.id if prompt else None

return GetChatHistoryOutput(
**{
@@ -184,7 +195,9 @@ class QABaseBrainPicking(BaseBrainPicking):
"user_message": question.question,
"assistant": answer,
"message_time": new_chat.message_time,
"prompt_title": prompt.title if prompt else None,
"prompt_title": self.prompt_to_use.title
if self.prompt_to_use
else None,
"brain_name": brain.name if brain else None,
"message_id": new_chat.message_id,
}
@@ -228,13 +241,14 @@ class QABaseBrainPicking(BaseBrainPicking):
finally:
event.set()

prompt_content = self.prompt_to_use.content if self.prompt_to_use else None
run = asyncio.create_task(
wrap_done(
qa.acall(
{
"question": question.question,
"chat_history": transformed_history,
"custom_personality": self.get_prompt(),
"custom_personality": prompt_content,
}
),
callback.done,
@@ -242,14 +256,9 @@ class QABaseBrainPicking(BaseBrainPicking):
)

brain = None
prompt = None
prompt_id = None

if question.brain_id:
brain = get_brain_by_id(question.brain_id)
if brain and brain.prompt_id:
prompt = get_prompt_by_id(brain.prompt_id)
prompt_id = prompt.id if prompt else None

streamed_chat_history = update_chat_history(
CreateChatHistory(
@@ -258,7 +267,7 @@ class QABaseBrainPicking(BaseBrainPicking):
"user_message": question.question,
"assistant": "",
"brain_id": question.brain_id,
"prompt_id": prompt_id,
"prompt_id": self.prompt_to_use_id,
}
)
)
@@ -270,7 +279,9 @@ class QABaseBrainPicking(BaseBrainPicking):
"message_time": streamed_chat_history.message_time,
"user_message": question.question,
"assistant": "",
"prompt_title": prompt.title if prompt else None,
"prompt_title": self.prompt_to_use.title
if self.prompt_to_use
else None,
"brain_name": brain.name if brain else None,
}
)
@@ -289,14 +300,3 @@ class QABaseBrainPicking(BaseBrainPicking):
user_message=question.question,
assistant=assistant,
)

def get_prompt(self) -> str:
brain = get_brain_by_id(UUID(self.brain_id))
brain_prompt = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."

if brain and brain.prompt_id:
brain_prompt_object = get_prompt_by_id(brain.prompt_id)
if brain_prompt_object:
brain_prompt = brain_prompt_object.content

return brain_prompt
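
The qa_base.py hunks above replace the old get_prompt() lookup with a prompt_to_use property whose content, or QUIVR_DEFAULT_PROMPT when nothing is configured, is folded into the system template and passed as custom_personality. A minimal, self-contained sketch of that fallback, using a plain stand-in instead of the real models.prompt.Prompt:

from types import SimpleNamespace

QUIVR_DEFAULT_PROMPT = (
    "Your name is Quivr. You're a helpful assistant. If you don't know the answer, "
    "just say that you don't know, don't try to make up an answer."
)
system_template = "Use the following pieces of context to answer.\n\n{context}"  # abridged stand-in


def build_full_template(prompt_to_use):
    # Same shape as the hunk above: use the resolved prompt when there is one,
    # otherwise fall back to the Quivr default.
    prompt_content = prompt_to_use.content if prompt_to_use else QUIVR_DEFAULT_PROMPT
    return (
        "Here are your instructions to answer that you MUST ALWAYS Follow: "
        + prompt_content
        + ". "
        + system_template
    )


print(build_full_template(None))  # falls back to QUIVR_DEFAULT_PROMPT
print(build_full_template(SimpleNamespace(content="Answer like a pirate.")))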
@@ -1,34 +1,35 @@
import asyncio
import json
from typing import AsyncIterable, Awaitable, List, Optional
from uuid import UUID

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.llms.base import BaseLLM
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from logger import get_logger
from models.chats import ChatQuestion
from models.databases.supabase.chats import CreateChatHistory
from models.prompt import Prompt
from pydantic import BaseModel
from repository.chat import (
update_message_by_id,
GetChatHistoryOutput,
format_chat_history,
format_history_to_openai_mesages,
get_chat_history,
update_chat_history,
format_history_to_openai_mesages,
GetChatHistoryOutput,
update_message_by_id,
)
from logger import get_logger
from models import ChatQuestion

from pydantic import BaseModel

from typing import AsyncIterable, Awaitable, List
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id

logger = get_logger(__name__)
SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."

class HeadlessQA(BaseModel):
@@ -40,6 +41,7 @@ class HeadlessQA(BaseModel):
streaming: bool = False
chat_id: str = None # type: ignore
callbacks: List[AsyncIteratorCallbackHandler] = None # type: ignore
prompt_id: Optional[UUID]

def _determine_api_key(self, openai_api_key, user_openai_api_key):
"""If user provided an API key, use it."""
@@ -74,6 +76,14 @@ class HeadlessQA(BaseModel):
self.streaming
) # pyright: ignore reportPrivateUsage=none

@property
def prompt_to_use(self) -> Optional[Prompt]:
return get_prompt_to_use(None, self.prompt_id)

@property
def prompt_to_use_id(self) -> Optional[UUID]:
return get_prompt_to_use_id(None, self.prompt_id)

def _create_llm(
self, model, temperature=0, streaming=False, callbacks=None
) -> BaseLLM:
@@ -104,11 +114,19 @@ class HeadlessQA(BaseModel):
self, chat_id: UUID, question: ChatQuestion
) -> GetChatHistoryOutput:
transformed_history = format_chat_history(get_chat_history(self.chat_id))
messages = format_history_to_openai_mesages(transformed_history, SYSTEM_MESSAGE, question.question)
prompt_content = (
self.prompt_to_use.content if self.prompt_to_use else SYSTEM_MESSAGE
)

messages = format_history_to_openai_mesages(
transformed_history, prompt_content, question.question
)
answering_llm = self._create_llm(
model=self.model, streaming=False, callbacks=self.callbacks
)
model_prediction = answering_llm.predict_messages(messages) # pyright: ignore reportPrivateUsage=none
model_prediction = answering_llm.predict_messages(
messages # pyright: ignore reportPrivateUsage=none
)
answer = model_prediction.content

new_chat = update_chat_history(
@@ -118,7 +136,7 @@ class HeadlessQA(BaseModel):
"user_message": question.question,
"assistant": answer,
"brain_id": None,
"prompt_id": None,
"prompt_id": self.prompt_to_use_id,
}
)
)
@@ -129,7 +147,9 @@ class HeadlessQA(BaseModel):
"user_message": question.question,
"assistant": answer,
"message_time": new_chat.message_time,
"prompt_title": None,
"prompt_title": self.prompt_to_use.title
if self.prompt_to_use
else None,
"brain_name": None,
"message_id": new_chat.message_id,
}
@@ -142,7 +162,13 @@ class HeadlessQA(BaseModel):
self.callbacks = [callback]

transformed_history = format_chat_history(get_chat_history(self.chat_id))
messages = format_history_to_openai_mesages(transformed_history, SYSTEM_MESSAGE, question.question)
prompt_content = (
self.prompt_to_use.content if self.prompt_to_use else SYSTEM_MESSAGE
)

messages = format_history_to_openai_mesages(
transformed_history, prompt_content, question.question
)
answering_llm = self._create_llm(
model=self.model, streaming=True, callbacks=self.callbacks
)
@@ -159,6 +185,7 @@ class HeadlessQA(BaseModel):
logger.error(f"Caught exception: {e}")
finally:
event.set()

run = asyncio.create_task(
wrap_done(
headlessChain.acall({}),
@@ -173,7 +200,7 @@ class HeadlessQA(BaseModel):
"user_message": question.question,
"assistant": "",
"brain_id": None,
"prompt_id": None,
"prompt_id": self.prompt_to_use_id,
}
)
)
@@ -185,15 +212,17 @@ class HeadlessQA(BaseModel):
"message_time": streamed_chat_history.message_time,
"user_message": question.question,
"assistant": "",
"prompt_title": None,
"prompt_title": self.prompt_to_use.title
if self.prompt_to_use
else None,
"brain_name": None,
}
)

async for token in callback.aiter():
logger.info("Token: %s", token) # type: ignore
response_tokens.append(token) # type: ignore
streamed_chat_history.assistant = token # type: ignore
logger.info("Token: %s", token)
response_tokens.append(token)
streamed_chat_history.assistant = token
yield f"data: {json.dumps(streamed_chat_history.dict())}"

await run
backend/llm/utils/get_prompt_to_use.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from typing import Optional
from uuid import UUID

from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from models.prompt import Prompt
from repository.prompt import get_prompt_by_id


def get_prompt_to_use(
    brain_id: Optional[UUID], prompt_id: Optional[UUID]
) -> Optional[Prompt]:
    prompt_to_use_id = get_prompt_to_use_id(brain_id, prompt_id)
    if prompt_to_use_id is None:
        return None

    return get_prompt_by_id(prompt_to_use_id)
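
A hedged usage sketch of this helper: the repository call is stubbed with an in-memory dict (the real get_prompt_by_id reads from Supabase) and the id resolution is simplified, so the example is self-contained and only illustrates the None short-circuit and the fetch:

from types import SimpleNamespace
from typing import Optional
from uuid import UUID, uuid4

FAKE_PROMPTS: dict = {}  # stand-in for the prompts table


def get_prompt_by_id(prompt_id: UUID):  # stubbed repository.prompt.get_prompt_by_id
    return FAKE_PROMPTS.get(prompt_id)


def get_prompt_to_use(brain_id: Optional[UUID], prompt_id: Optional[UUID]):
    # Mirrors the new file: resolve an id first, then fetch the prompt record.
    prompt_to_use_id = prompt_id if prompt_id else None  # brain lookup omitted in this stub
    if prompt_to_use_id is None:
        return None
    return get_prompt_by_id(prompt_to_use_id)


override_id = uuid4()
FAKE_PROMPTS[override_id] = SimpleNamespace(id=override_id, title="Pirate", content="Answer like a pirate.")
print(get_prompt_to_use(None, override_id).title)  # Pirate
print(get_prompt_to_use(None, None))               # None -> callers fall back to the default prompt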
backend/llm/utils/get_prompt_to_use_id.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from typing import Optional
from uuid import UUID

from repository.brain import get_brain_prompt_id


def get_prompt_to_use_id(
    brain_id: Optional[UUID], prompt_id: Optional[UUID]
) -> Optional[UUID]:
    if brain_id is None and prompt_id is None:
        return None

    return (
        prompt_id if prompt_id else get_brain_prompt_id(brain_id) if brain_id else None
    )
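
This helper encodes the precedence the commit introduces: a prompt_id sent with the chat question wins over the brain's default prompt, and a headless chat with neither returns None. A self-contained sketch of that rule, with the repository.brain helper stubbed to a fixed hypothetical id:

from typing import Optional
from uuid import UUID, uuid4

BRAIN_DEFAULT_PROMPT_ID = uuid4()  # hypothetical brain default


def get_brain_prompt_id(brain_id: UUID) -> Optional[UUID]:  # stubbed repository.brain helper
    return BRAIN_DEFAULT_PROMPT_ID


def get_prompt_to_use_id(
    brain_id: Optional[UUID], prompt_id: Optional[UUID]
) -> Optional[UUID]:
    if brain_id is None and prompt_id is None:
        return None
    return (
        prompt_id if prompt_id else get_brain_prompt_id(brain_id) if brain_id else None
    )


chat_override = uuid4()
assert get_prompt_to_use_id(uuid4(), chat_override) == chat_override   # chat override wins
assert get_prompt_to_use_id(uuid4(), None) == BRAIN_DEFAULT_PROMPT_ID  # falls back to the brain prompt
assert get_prompt_to_use_id(None, chat_override) == chat_override      # headless chat with an explicit prompt
assert get_prompt_to_use_id(None, None) is None                        # headless chat, no prompt at all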
@@ -22,3 +22,4 @@ class ChatQuestion(BaseModel):
temperature: Optional[float]
max_tokens: Optional[int]
brain_id: Optional[UUID]
prompt_id: Optional[UUID]
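
With prompt_id now part of ChatQuestion, a chat request can carry the override directly in its body. A hedged example payload (fields outside this hunk, such as question and model, are assumptions; the UUIDs are placeholders):

chat_question_payload = {
    "question": "Summarize the onboarding notes",  # assumed field, used as question.question above
    "model": "gpt-3.5-turbo",                      # assumed field
    "temperature": 0.0,
    "max_tokens": 256,
    "brain_id": "11111111-1111-1111-1111-111111111111",   # placeholder: target brain
    "prompt_id": "22222222-2222-2222-2222-222222222222",  # placeholder: overrides that brain's default prompt
}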
@@ -213,6 +213,7 @@ async def create_question_handler(
temperature=chat_question.temperature,
brain_id=str(brain_id),
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
prompt_id=chat_question.prompt_id,
)
else:
gpt_answer_generator = HeadlessQA(
@@ -221,6 +222,7 @@ async def create_question_handler(
max_tokens=chat_question.max_tokens,
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
chat_id=str(chat_id),
prompt_id=chat_question.prompt_id,
)

chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)
@@ -296,6 +298,7 @@ async def create_stream_question_handler(
brain_id=str(brain_id),
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=True,
prompt_id=chat_question.prompt_id,
)
else:
gpt_answer_generator = HeadlessQA(
@@ -311,6 +314,7 @@ async def create_stream_question_handler(
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
chat_id=str(chat_id),
streaming=True,
prompt_id=chat_question.prompt_id,
)

print("streaming")
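
Both handlers now forward chat_question.prompt_id whichever generator they build. A condensed, runnable sketch of that branch logic with the two classes stubbed (the real handlers pass many more arguments than these hunks show):

from typing import Optional
from uuid import UUID, uuid4


class OpenAIBrainPicking:  # stub, stands in for the brain-backed generator
    def __init__(self, chat_id: str, brain_id: str, prompt_id: Optional[UUID], **kwargs):
        self.prompt_id = prompt_id


class HeadlessQA:  # stub, stands in for the brainless generator
    def __init__(self, chat_id: str, prompt_id: Optional[UUID], **kwargs):
        self.prompt_id = prompt_id


def pick_generator(chat_id: UUID, brain_id: Optional[UUID], prompt_id: Optional[UUID]):
    # Same shape as create_question_handler / create_stream_question_handler:
    # a brain-backed chat gets OpenAIBrainPicking, a brainless chat gets HeadlessQA,
    # and the per-chat prompt override is threaded through either way.
    if brain_id:
        return OpenAIBrainPicking(chat_id=str(chat_id), brain_id=str(brain_id), prompt_id=prompt_id)
    return HeadlessQA(chat_id=str(chat_id), prompt_id=prompt_id)


print(type(pick_generator(uuid4(), uuid4(), uuid4())).__name__)  # OpenAIBrainPicking
print(type(pick_generator(uuid4(), None, uuid4())).__name__)     # HeadlessQA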
@@ -6,6 +6,7 @@ export type ChatQuestion = {
temperature?: number;
max_tokens?: number;
brain_id?: string;
prompt_id?: string;
};
export type ChatHistory = {
chat_id: string;