mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-14 17:03:29 +03:00
feat: chat with compositeBrain ( with/out streaming) (#1883)
# DONE - generate_stream, generate and save answer in BE # TODO - Create an intermediary make_streaming_recursive_tool_calls async function - Save intermediary answers in new message logs column then fetch and display in front
This commit is contained in:
parent
7c6c4cf10e
commit
742e9bdfba
@ -1,24 +1,21 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import nest_asyncio
|
||||
from fastapi import HTTPException
|
||||
from litellm import completion
|
||||
from logger import get_logger
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.dto.chats import ChatQuestion
|
||||
from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
|
||||
from llm.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from llm.qa_interface import QAInterface
|
||||
from llm.utils.call_brain_api import call_brain_api
|
||||
from llm.utils.get_api_brain_definition_as_json_schema import (
|
||||
get_api_brain_definition_as_json_schema,
|
||||
)
|
||||
from logger import get_logger
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.dto.chats import ChatQuestion
|
||||
from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
|
||||
brain_service = BrainService()
|
||||
chat_service = ChatService()
|
||||
@ -151,6 +148,7 @@ class APIBrainQA(KnowledgeBrainQA, QAInterface):
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
save_answer: bool = True,
|
||||
should_log_steps: Optional[bool] = True,
|
||||
):
|
||||
if not question.brain_id:
|
||||
@ -181,30 +179,45 @@ class APIBrainQA(KnowledgeBrainQA, QAInterface):
|
||||
|
||||
messages.append({"role": "user", "content": question.question})
|
||||
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
if save_answer:
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": None,
|
||||
"message_time": None,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
}
|
||||
)
|
||||
response_tokens = []
|
||||
async for value in self.make_completion(
|
||||
messages=messages,
|
||||
@ -220,32 +233,165 @@ class APIBrainQA(KnowledgeBrainQA, QAInterface):
|
||||
for token in response_tokens
|
||||
if not token.startswith("🧠<") and not token.endswith(">🧠")
|
||||
]
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant="".join(response_tokens),
|
||||
if save_answer:
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant="".join(response_tokens),
|
||||
)
|
||||
|
||||
def make_completion_without_streaming(
|
||||
self,
|
||||
messages,
|
||||
functions,
|
||||
brain_id: UUID,
|
||||
recursive_count=0,
|
||||
should_log_steps=False,
|
||||
):
|
||||
if recursive_count > 5:
|
||||
print(
|
||||
"The assistant is having issues and took more than 5 calls to the API. Please try again later or an other instruction."
|
||||
)
|
||||
return
|
||||
|
||||
if should_log_steps:
|
||||
print("🧠<Deciding what to do>🧠")
|
||||
|
||||
response = completion(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
stream=False,
|
||||
function_call="auto",
|
||||
)
|
||||
|
||||
def generate_answer(self, chat_id: UUID, question: ChatQuestion):
|
||||
async def a_generate_answer():
|
||||
api_brain_question_answer: GetChatHistoryOutput = None
|
||||
response_message = response.choices[0].message
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
|
||||
async for answer in self.generate_stream(
|
||||
chat_id, question, should_log_steps=False
|
||||
):
|
||||
answer = answer.split("data: ")[1]
|
||||
answer_parsed: GetChatHistoryOutput = GetChatHistoryOutput(
|
||||
**json.loads(answer)
|
||||
if finish_reason == "function_call":
|
||||
function_call = response_message.function_call
|
||||
try:
|
||||
arguments = json.loads(function_call.arguments)
|
||||
|
||||
except Exception:
|
||||
arguments = {}
|
||||
|
||||
if should_log_steps:
|
||||
print(f"🧠<Calling {brain_id} with arguments {arguments}>🧠")
|
||||
|
||||
try:
|
||||
api_call_response = call_brain_api(
|
||||
brain_id=brain_id,
|
||||
user_id=self.user_id,
|
||||
arguments=arguments,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Error while calling API: {e}",
|
||||
)
|
||||
if api_brain_question_answer is None:
|
||||
api_brain_question_answer = answer_parsed
|
||||
else:
|
||||
api_brain_question_answer.assistant += answer_parsed.assistant
|
||||
|
||||
return api_brain_question_answer
|
||||
function_name = function_call.name
|
||||
messages.append(
|
||||
{
|
||||
"role": "function",
|
||||
"name": function_call.name,
|
||||
"content": f"The function {function_name} was called and gave The following answer:(data from function) {api_call_response} (end of data from function). Don't call this function again unless there was an error or extremely necessary and asked specifically by the user.",
|
||||
}
|
||||
)
|
||||
|
||||
nest_asyncio.apply()
|
||||
loop = asyncio.get_event_loop()
|
||||
result = loop.run_until_complete(a_generate_answer())
|
||||
return self.make_completion_without_streaming(
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
brain_id=brain_id,
|
||||
recursive_count=recursive_count + 1,
|
||||
should_log_steps=should_log_steps,
|
||||
)
|
||||
|
||||
return result
|
||||
if finish_reason == "stop":
|
||||
return response_message
|
||||
|
||||
else:
|
||||
print("Never ending completion")
|
||||
|
||||
def generate_answer(
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
save_answer: bool = True,
|
||||
):
|
||||
if not question.brain_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No brain id provided in the question"
|
||||
)
|
||||
|
||||
brain = brain_service.get_brain_by_id(question.brain_id)
|
||||
|
||||
if not brain:
|
||||
raise HTTPException(status_code=404, detail="Brain not found")
|
||||
|
||||
prompt_content = "You are a helpful assistant that can access functions to help answer questions. If there are information missing in the question, you can ask follow up questions to get more information to the user. Once all the information is available, you can call the function to get the answer."
|
||||
|
||||
if self.prompt_to_use:
|
||||
prompt_content += self.prompt_to_use.content
|
||||
|
||||
messages = [{"role": "system", "content": prompt_content}]
|
||||
|
||||
history = chat_service.get_chat_history(self.chat_id)
|
||||
|
||||
for message in history:
|
||||
formatted_message = [
|
||||
{"role": "user", "content": message.user_message},
|
||||
{"role": "assistant", "content": message.assistant},
|
||||
]
|
||||
messages.extend(formatted_message)
|
||||
|
||||
messages.append({"role": "user", "content": question.question})
|
||||
|
||||
response = self.make_completion_without_streaming(
|
||||
messages=messages,
|
||||
functions=[get_api_brain_definition_as_json_schema(brain)],
|
||||
brain_id=question.brain_id,
|
||||
should_log_steps=False,
|
||||
)
|
||||
|
||||
answer = response.content
|
||||
if save_answer:
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
"message_id": new_chat.message_id,
|
||||
}
|
||||
)
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": "123",
|
||||
"prompt_title": None,
|
||||
"brain_name": brain.name,
|
||||
"message_id": None,
|
||||
}
|
||||
)
|
||||
|
589
backend/llm/composite_brain_qa.py
Normal file
589
backend/llm/composite_brain_qa.py
Normal file
@ -0,0 +1,589 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from litellm import completion
|
||||
from llm.api_brain_qa import APIBrainQA
|
||||
from llm.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from llm.qa_headless import HeadlessQA
|
||||
from logger import get_logger
|
||||
from modules.brain.entity.brain_entity import BrainEntity, BrainType
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.dto.chats import ChatQuestion
|
||||
from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import (
|
||||
BrainCompletionOutput,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
GetChatHistoryOutput,
|
||||
)
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
|
||||
brain_service = BrainService()
|
||||
chat_service = ChatService()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def format_brain_to_tool(brain):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": str(brain.id),
|
||||
"description": brain.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "Question to ask the brain",
|
||||
},
|
||||
},
|
||||
"required": ["question"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class CompositeBrainQA(
|
||||
KnowledgeBrainQA,
|
||||
):
|
||||
user_id: UUID
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
brain_id: str,
|
||||
chat_id: str,
|
||||
streaming: bool = False,
|
||||
prompt_id: Optional[UUID] = None,
|
||||
**kwargs,
|
||||
):
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=400, detail="Cannot find user id")
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
brain_id=brain_id,
|
||||
chat_id=chat_id,
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.user_id = user_id
|
||||
|
||||
def get_answer_generator_from_brain_type(self, brain: BrainEntity):
|
||||
if brain.brain_type == BrainType.COMPOSITE:
|
||||
return self.generate_answer
|
||||
elif brain.brain_type == BrainType.API:
|
||||
return APIBrainQA(
|
||||
brain_id=str(brain.id),
|
||||
chat_id=self.chat_id,
|
||||
model=self.model,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
streaming=self.streaming,
|
||||
prompt_id=self.prompt_id,
|
||||
user_id=str(self.user_id),
|
||||
).generate_answer
|
||||
elif brain.brain_type == BrainType.DOC:
|
||||
return KnowledgeBrainQA(
|
||||
brain_id=str(brain.id),
|
||||
chat_id=self.chat_id,
|
||||
model=self.model,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
streaming=self.streaming,
|
||||
prompt_id=self.prompt_id,
|
||||
).generate_answer
|
||||
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool
|
||||
) -> str:
|
||||
brain = brain_service.get_brain_by_id(question.brain_id)
|
||||
|
||||
connected_brains = brain_service.get_connected_brains(self.brain_id)
|
||||
|
||||
if not connected_brains:
|
||||
response = HeadlessQA(
|
||||
chat_id=chat_id,
|
||||
model=self.model,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
streaming=self.streaming,
|
||||
prompt_id=self.prompt_id,
|
||||
).generate_answer(chat_id, question, save_answer=False)
|
||||
if save_answer:
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": response.assistant,
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": response.assistant,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name,
|
||||
"message_id": new_chat.message_id,
|
||||
}
|
||||
)
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": response.assistant,
|
||||
"message_time": None,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name,
|
||||
"message_id": None,
|
||||
}
|
||||
)
|
||||
|
||||
tools = []
|
||||
available_functions = {}
|
||||
|
||||
connected_brains_details = {}
|
||||
for connected_brain_id in connected_brains:
|
||||
connected_brain = brain_service.get_brain_by_id(connected_brain_id)
|
||||
if connected_brain is None:
|
||||
continue
|
||||
|
||||
tools.append(format_brain_to_tool(connected_brain))
|
||||
|
||||
available_functions[
|
||||
connected_brain_id
|
||||
] = self.get_answer_generator_from_brain_type(connected_brain)
|
||||
|
||||
connected_brains_details[str(connected_brain.id)] = connected_brain
|
||||
|
||||
CHOOSE_BRAIN_FROM_TOOLS_PROMPT = (
|
||||
"Based on the provided user content, find the most appropriate tools to answer"
|
||||
+ "If you can't find any tool to answer and only then, and if you can answer without using any tool. In that case, let the user know that you are not using any particular brain (i.e tool) "
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": CHOOSE_BRAIN_FROM_TOOLS_PROMPT}]
|
||||
|
||||
history = chat_service.get_chat_history(self.chat_id)
|
||||
|
||||
for message in history:
|
||||
formatted_message = [
|
||||
{"role": "user", "content": message.user_message},
|
||||
{"role": "assistant", "content": message.assistant},
|
||||
]
|
||||
messages.extend(formatted_message)
|
||||
|
||||
messages.append({"role": "user", "content": question.question})
|
||||
|
||||
response = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
|
||||
brain_completion_output = self.make_recursive_tool_calls(
|
||||
messages,
|
||||
question,
|
||||
chat_id,
|
||||
tools,
|
||||
available_functions,
|
||||
recursive_count=0,
|
||||
last_completion_response=response.choices[0],
|
||||
)
|
||||
|
||||
if brain_completion_output:
|
||||
answer = brain_completion_output.response.message.content
|
||||
new_chat = None
|
||||
if save_answer:
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": brain_completion_output.response.message.content,
|
||||
"message_time": new_chat.message_time if new_chat else None,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
"message_id": new_chat.message_id if new_chat else None,
|
||||
}
|
||||
)
|
||||
|
||||
def make_recursive_tool_calls(
|
||||
self,
|
||||
messages,
|
||||
question,
|
||||
chat_id,
|
||||
tools=[],
|
||||
available_functions={},
|
||||
recursive_count=0,
|
||||
last_completion_response: CompletionResponse = None,
|
||||
):
|
||||
if recursive_count > 5:
|
||||
print(
|
||||
"The assistant is having issues and took more than 5 calls to the tools. Please try again later or an other instruction."
|
||||
)
|
||||
return None
|
||||
|
||||
finish_reason = last_completion_response.finish_reason
|
||||
if finish_reason == "stop":
|
||||
messages.append(last_completion_response.message)
|
||||
return BrainCompletionOutput(
|
||||
**{
|
||||
"messages": messages,
|
||||
"question": question.question,
|
||||
"response": last_completion_response,
|
||||
}
|
||||
)
|
||||
|
||||
if finish_reason == "tool_calls":
|
||||
response_message: CompletionMessage = last_completion_response.message
|
||||
tool_calls = response_message.tool_calls
|
||||
|
||||
messages.append(response_message)
|
||||
|
||||
if (
|
||||
len(tool_calls) == 0
|
||||
or tool_calls is None
|
||||
or len(available_functions) == 0
|
||||
):
|
||||
return
|
||||
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call.function.name
|
||||
function_to_call = available_functions[function_name]
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
question = ChatQuestion(
|
||||
question=function_args["question"], brain_id=function_name
|
||||
)
|
||||
|
||||
print("querying brain", function_name)
|
||||
# TODO: extract chat_id from generate_answer function of XBrainQA
|
||||
function_response = function_to_call(
|
||||
chat_id=chat_id,
|
||||
question=question,
|
||||
save_answer=False,
|
||||
)
|
||||
|
||||
print("brain_answer", function_response.assistant)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call.id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": function_response.assistant,
|
||||
}
|
||||
)
|
||||
|
||||
PROMPT_2 = "If initial question can be answered by our conversation messages, then give an answer and end the conversation."
|
||||
|
||||
messages.append({"role": "system", "content": PROMPT_2})
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
logger.info(
|
||||
f"Message {idx}: Role - {msg['role']}, Content - {msg['content']}"
|
||||
)
|
||||
|
||||
response_after_tools_answers = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
|
||||
return self.make_recursive_tool_calls(
|
||||
messages,
|
||||
question,
|
||||
chat_id,
|
||||
tools,
|
||||
available_functions,
|
||||
recursive_count=recursive_count + 1,
|
||||
last_completion_response=response_after_tools_answers.choices[0],
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
save_answer: bool,
|
||||
should_log_steps: Optional[bool] = True,
|
||||
):
|
||||
brain = brain_service.get_brain_by_id(question.brain_id)
|
||||
if save_answer:
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": None,
|
||||
"message_time": None,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
}
|
||||
)
|
||||
|
||||
connected_brains = brain_service.get_connected_brains(self.brain_id)
|
||||
|
||||
if not connected_brains:
|
||||
headlesss_answer = HeadlessQA(
|
||||
chat_id=chat_id,
|
||||
model=self.model,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
streaming=self.streaming,
|
||||
prompt_id=self.prompt_id,
|
||||
).generate_stream(chat_id, question)
|
||||
|
||||
response_tokens = []
|
||||
async for value in headlesss_answer:
|
||||
streamed_chat_history.assistant = value
|
||||
response_tokens.append(value)
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
|
||||
if save_answer:
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant="".join(response_tokens),
|
||||
)
|
||||
|
||||
tools = []
|
||||
available_functions = {}
|
||||
|
||||
connected_brains_details = {}
|
||||
for brain_id in connected_brains:
|
||||
brain = brain_service.get_brain_by_id(brain_id)
|
||||
if brain == None:
|
||||
continue
|
||||
|
||||
tools.append(format_brain_to_tool(brain))
|
||||
|
||||
available_functions[brain_id] = self.get_answer_generator_from_brain_type(
|
||||
brain
|
||||
)
|
||||
|
||||
connected_brains_details[str(brain.id)] = brain
|
||||
|
||||
CHOOSE_BRAIN_FROM_TOOLS_PROMPT = (
|
||||
"Based on the provided user content, find the most appropriate tools to answer"
|
||||
+ "If you can't find any tool to answer and only then, and if you can answer without using any tool. In that case, let the user know that you are not using any particular brain (i.e tool) "
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": CHOOSE_BRAIN_FROM_TOOLS_PROMPT}]
|
||||
|
||||
history = chat_service.get_chat_history(self.chat_id)
|
||||
|
||||
for message in history:
|
||||
formatted_message = [
|
||||
{"role": "user", "content": message.user_message},
|
||||
{"role": "assistant", "content": message.assistant},
|
||||
]
|
||||
if message.assistant is None:
|
||||
print(message)
|
||||
messages.extend(formatted_message)
|
||||
|
||||
messages.append({"role": "user", "content": question.question})
|
||||
|
||||
initial_response = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
stream=True,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
|
||||
response_tokens = []
|
||||
tool_calls_aggregate = []
|
||||
for chunk in initial_response:
|
||||
content = chunk.choices[0].delta.content
|
||||
if content is not None:
|
||||
# Need to store it ?
|
||||
streamed_chat_history.assistant = content
|
||||
response_tokens.append(chunk.choices[0].delta.content)
|
||||
|
||||
if save_answer:
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
else:
|
||||
yield f"🧠<' {chunk.choices[0].delta.content}"
|
||||
|
||||
if (
|
||||
"tool_calls" in chunk.choices[0].delta
|
||||
and chunk.choices[0].delta.tool_calls is not None
|
||||
):
|
||||
tool_calls = chunk.choices[0].delta.tool_calls
|
||||
for tool_call in tool_calls:
|
||||
id = tool_call.id
|
||||
name = tool_call.function.name
|
||||
if id and name:
|
||||
tool_calls_aggregate += [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"function": {
|
||||
"arguments": tool_call.function.arguments,
|
||||
"name": tool_call.function.name,
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
|
||||
else:
|
||||
try:
|
||||
tool_calls_aggregate[tool_call.index]["function"][
|
||||
"arguments"
|
||||
] += tool_call.function.arguments
|
||||
except IndexError:
|
||||
print("TOOL_CALL_INDEX error", tool_call.index)
|
||||
print("TOOL_CALLS_AGGREGATE error", tool_calls_aggregate)
|
||||
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
if finish_reason == "stop":
|
||||
if save_answer:
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant="".join(
|
||||
[
|
||||
token
|
||||
for token in response_tokens
|
||||
if not token.startswith("🧠<")
|
||||
]
|
||||
),
|
||||
)
|
||||
break
|
||||
|
||||
if finish_reason == "tool_calls":
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": tool_calls_aggregate,
|
||||
"content": None,
|
||||
}
|
||||
)
|
||||
for tool_call in tool_calls_aggregate:
|
||||
function_name = tool_call["function"]["name"]
|
||||
queried_brain = connected_brains_details[function_name]
|
||||
function_to_call = available_functions[function_name]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
print("function_args", function_args["question"])
|
||||
question = ChatQuestion(
|
||||
question=function_args["question"], brain_id=queried_brain.id
|
||||
)
|
||||
|
||||
# yield f"🧠< Querying the brain {queried_brain.name} with the following arguments: {function_args} >🧠",
|
||||
|
||||
print(
|
||||
f"🧠< Querying the brain {queried_brain.name} with the following arguments: {function_args}",
|
||||
)
|
||||
function_response = function_to_call(
|
||||
chat_id=chat_id,
|
||||
question=question,
|
||||
save_answer=False,
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": function_response.assistant,
|
||||
}
|
||||
)
|
||||
|
||||
print("messages", messages)
|
||||
|
||||
PROMPT_2 = "If the last user's question can be answered by our conversation messages since then, then give an answer and end the conversation. If you need to ask question to the user to gather more information and give a more accurate answer, then ask the question and wait for the user's answer."
|
||||
# Otherwise, ask a new question to the assistant and choose brains you would like to ask questions."
|
||||
|
||||
messages.append({"role": "system", "content": PROMPT_2})
|
||||
|
||||
response_after_tools_answers = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
response_tokens = []
|
||||
for chunk in response_after_tools_answers:
|
||||
print("chunk_response_after_tools_answers", chunk)
|
||||
content = chunk.choices[0].delta.content
|
||||
if content:
|
||||
streamed_chat_history.assistant = content
|
||||
response_tokens.append(chunk.choices[0].delta.content)
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
if finish_reason == "stop":
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant="".join(
|
||||
[
|
||||
token
|
||||
for token in response_tokens
|
||||
if not token.startswith("🧠<")
|
||||
]
|
||||
),
|
||||
)
|
||||
break
|
||||
elif finish_reason is not None:
|
||||
# TODO: recursively call with tools (update prompt + create intermediary function )
|
||||
print("NO STOP")
|
||||
print(chunk.choices[0])
|
@ -5,6 +5,12 @@ from uuid import UUID
|
||||
|
||||
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import ConversationalRetrievalChain
|
||||
from llm.qa_interface import QAInterface
|
||||
from llm.rags.quivr_rag import QuivrRAG
|
||||
from llm.rags.rag_interface import RAGInterface
|
||||
from llm.utils.format_chat_history import format_chat_history
|
||||
from llm.utils.get_prompt_to_use import get_prompt_to_use
|
||||
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
from logger import get_logger
|
||||
from models import BrainSettings
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
@ -14,13 +20,6 @@ from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llm.qa_interface import QAInterface
|
||||
from llm.rags.quivr_rag import QuivrRAG
|
||||
from llm.rags.rag_interface import RAGInterface
|
||||
from llm.utils.format_chat_history import format_chat_history
|
||||
from llm.utils.get_prompt_to_use import get_prompt_to_use
|
||||
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
|
||||
|
||||
@ -89,14 +88,16 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
|
||||
@property
|
||||
def prompt_to_use(self):
|
||||
# TODO: move to prompt service or instruction or something
|
||||
return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)
|
||||
|
||||
@property
|
||||
def prompt_to_use_id(self) -> Optional[UUID]:
|
||||
# TODO: move to prompt service or instruction or something
|
||||
return get_prompt_to_use_id(UUID(self.brain_id), self.prompt_id)
|
||||
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> GetChatHistoryOutput:
|
||||
transformed_history = format_chat_history(
|
||||
chat_service.get_chat_history(self.chat_id)
|
||||
@ -128,39 +129,55 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
|
||||
answer = model_response["answer"]
|
||||
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
brain = None
|
||||
|
||||
if question.brain_id:
|
||||
brain = brain_service.get_brain_by_id(question.brain_id)
|
||||
|
||||
if save_answer:
|
||||
# save the answer to the database or not -> add a variable
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
"message_id": new_chat.message_id,
|
||||
}
|
||||
)
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": new_chat.message_time,
|
||||
"message_time": None,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
"message_id": new_chat.message_id,
|
||||
"brain_name": None,
|
||||
"message_id": None,
|
||||
}
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self, chat_id: UUID, question: ChatQuestion
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> AsyncIterable:
|
||||
history = chat_service.get_chat_history(self.chat_id)
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
@ -211,31 +228,46 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
if question.brain_id:
|
||||
brain = brain_service.get_brain_by_id(question.brain_id)
|
||||
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
if save_answer:
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": None,
|
||||
"message_time": None,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": brain.name if brain else None,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
async for token in callback.aiter():
|
||||
@ -274,10 +306,11 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
assistant += sources_string
|
||||
|
||||
try:
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant=assistant,
|
||||
)
|
||||
if save_answer:
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant=assistant,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error updating message by ID: %s", e)
|
||||
|
@ -8,6 +8,13 @@ from langchain.chains import LLMChain
|
||||
from langchain.chat_models import ChatLiteLLM
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from llm.qa_interface import QAInterface
|
||||
from llm.utils.format_chat_history import (
|
||||
format_chat_history,
|
||||
format_history_to_openai_mesages,
|
||||
)
|
||||
from llm.utils.get_prompt_to_use import get_prompt_to_use
|
||||
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
from logger import get_logger
|
||||
from models import BrainSettings # Importing settings related to the 'brain'
|
||||
from modules.chat.dto.chats import ChatQuestion
|
||||
@ -17,14 +24,6 @@ from modules.chat.service.chat_service import ChatService
|
||||
from modules.prompt.entity.prompt import Prompt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llm.qa_interface import QAInterface
|
||||
from llm.utils.format_chat_history import (
|
||||
format_chat_history,
|
||||
format_history_to_openai_mesages,
|
||||
)
|
||||
from llm.utils.get_prompt_to_use import get_prompt_to_use
|
||||
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer.When answering use markdown or any other techniques to display the content in a nice and aerated way."
|
||||
chat_service = ChatService()
|
||||
@ -102,7 +101,7 @@ class HeadlessQA(BaseModel, QAInterface):
|
||||
return CHAT_PROMPT
|
||||
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> GetChatHistoryOutput:
|
||||
# Move format_chat_history to chat service ?
|
||||
transformed_history = format_chat_history(
|
||||
@ -122,35 +121,49 @@ class HeadlessQA(BaseModel, QAInterface):
|
||||
)
|
||||
model_prediction = answering_llm.predict_messages(messages)
|
||||
answer = model_prediction.content
|
||||
if save_answer:
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": None,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": None,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": None,
|
||||
"message_id": new_chat.message_id,
|
||||
}
|
||||
)
|
||||
else:
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": None,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": None,
|
||||
"message_id": None,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": None,
|
||||
"message_id": new_chat.message_id,
|
||||
}
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self, chat_id: UUID, question: ChatQuestion
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> AsyncIterable:
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
self.callbacks = [callback]
|
||||
@ -191,31 +204,46 @@ class HeadlessQA(BaseModel, QAInterface):
|
||||
),
|
||||
)
|
||||
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
if save_answer:
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": None,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": None,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": None,
|
||||
"message_time": None,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": None,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": self.prompt_to_use.title
|
||||
if self.prompt_to_use
|
||||
else None,
|
||||
"brain_name": None,
|
||||
}
|
||||
)
|
||||
|
||||
async for token in callback.aiter():
|
||||
logger.info("Token: %s", token)
|
||||
@ -226,11 +254,12 @@ class HeadlessQA(BaseModel, QAInterface):
|
||||
await run
|
||||
assistant = "".join(response_tokens)
|
||||
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant=assistant,
|
||||
)
|
||||
if save_answer:
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant=assistant,
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
@ -12,7 +12,11 @@ class QAInterface(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion, should, *custom_params: tuple
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
save_answer: bool,
|
||||
*custom_params: tuple
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"generate_answer is an abstract method and must be implemented"
|
||||
@ -20,7 +24,11 @@ class QAInterface(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def generate_stream(
|
||||
self, chat_id: UUID, question: ChatQuestion, *custom_params: tuple
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
save_answer: bool,
|
||||
*custom_params: tuple
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"generate_stream is an abstract method and must be implemented"
|
||||
|
@ -235,6 +235,11 @@ class BrainService:
|
||||
)
|
||||
return brain
|
||||
|
||||
def get_connected_brains(self, brain_id: UUID) -> list[BrainEntity]:
|
||||
return self.composite_brains_connections_repository.get_connected_brains(
|
||||
brain_id
|
||||
)
|
||||
|
||||
def get_public_brains(self) -> list[PublicBrain]:
|
||||
return self.brain_repository.get_public_brains()
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from fastapi import HTTPException
|
||||
from llm.api_brain_qa import APIBrainQA
|
||||
from llm.composite_brain_qa import CompositeBrainQA
|
||||
from llm.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from modules.brain.entity.brain_entity import BrainType, RoleEnum
|
||||
from modules.brain.service.brain_authorization_service import (
|
||||
@ -58,6 +59,17 @@ class BrainfulChat(ChatInterface):
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
)
|
||||
if brain.brain_type == BrainType.COMPOSITE:
|
||||
return CompositeBrainQA(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
brain_id=brain_id,
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return APIBrainQA(
|
||||
chat_id=chat_id,
|
||||
|
@ -15,7 +15,6 @@ from modules.chat.dto.inputs import (
|
||||
CreateChatProperties,
|
||||
QuestionAndAnswer,
|
||||
)
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.entity.chat import Chat
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
from modules.notification.service.notification_service import NotificationService
|
||||
@ -118,7 +117,7 @@ async def create_question_handler(
|
||||
| UUID
|
||||
| None = Query(..., description="The ID of the brain"),
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
) -> GetChatHistoryOutput:
|
||||
):
|
||||
"""
|
||||
Add a new question to the chat.
|
||||
"""
|
||||
@ -169,7 +168,9 @@ async def create_question_handler(
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)
|
||||
chat_answer = gpt_answer_generator.generate_answer(
|
||||
chat_id, chat_question, save_answer=True
|
||||
)
|
||||
|
||||
return chat_answer
|
||||
except HTTPException as e:
|
||||
@ -244,7 +245,9 @@ async def create_stream_question_handler(
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
gpt_answer_generator.generate_stream(chat_id, chat_question),
|
||||
gpt_answer_generator.generate_stream(
|
||||
chat_id, chat_question, save_answer=True
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -6,10 +6,10 @@ from pydantic import BaseModel
|
||||
|
||||
class GetChatHistoryOutput(BaseModel):
|
||||
chat_id: UUID
|
||||
message_id: UUID
|
||||
message_id: Optional[UUID] | str
|
||||
user_message: str
|
||||
assistant: str
|
||||
message_time: str
|
||||
message_time: Optional[str]
|
||||
prompt_title: Optional[str] | None
|
||||
brain_name: Optional[str] | None
|
||||
|
||||
@ -19,3 +19,32 @@ class GetChatHistoryOutput(BaseModel):
|
||||
chat_history["message_id"] = str(chat_history.get("message_id"))
|
||||
|
||||
return chat_history
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
arguments: str
|
||||
name: str
|
||||
|
||||
|
||||
class ChatCompletionMessageToolCall(BaseModel):
|
||||
id: str
|
||||
function: FunctionCall
|
||||
type: str = "function"
|
||||
|
||||
|
||||
class CompletionMessage(BaseModel):
|
||||
# = "assistant" | "user" | "system" | "tool"
|
||||
role: str
|
||||
content: str | None
|
||||
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
finish_reason: str
|
||||
message: CompletionMessage
|
||||
|
||||
|
||||
class BrainCompletionOutput(BaseModel):
|
||||
messages: List[CompletionMessage]
|
||||
question: str
|
||||
response: CompletionResponse
|
||||
|
Loading…
Reference in New Issue
Block a user