quivr/backend/llm/composite_brain_qa.py
Stan Girard 07563bd079
fix: 🐛 models (#1973)
Fixed the case where no model is chosen.

# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
2024-01-04 12:31:17 +01:00

590 lines
22 KiB
Python

import json
from typing import Optional
from uuid import UUID
from fastapi import HTTPException
from litellm import completion
from llm.api_brain_qa import APIBrainQA
from llm.knowledge_brain_qa import KnowledgeBrainQA
from llm.qa_headless import HeadlessQA
from logger import get_logger
from modules.brain.entity.brain_entity import BrainEntity, BrainType
from modules.brain.service.brain_service import BrainService
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import (
BrainCompletionOutput,
CompletionMessage,
CompletionResponse,
GetChatHistoryOutput,
)
from modules.chat.service.chat_service import ChatService
logger = get_logger(__name__)

# Module-wide service instances used by the functions and methods below.
brain_service, chat_service = BrainService(), ChatService()
def format_brain_to_tool(brain):
    """Describe ``brain`` as an OpenAI function-calling tool definition.

    The tool is named ``str(brain.id)``, so the name the model returns in a tool
    call can be mapped back to the brain. Its single required ``question``
    argument is the sub-question the model wants that brain to answer.

    When the brain has no description, the brain name (then an empty string) is
    used instead: the tools schema expects a string description, and sending
    ``null`` gets the request rejected.
    """
    description = brain.description or brain.name or ""
    return {
        "type": "function",
        "function": {
            "name": str(brain.id),
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    "question": {
                        "type": "string",
                        "description": "Question to ask the brain",
                    },
                },
                "required": ["question"],
            },
        },
    }
# Model used to pick the connected brains (offered as tools) and to write the final answer.
_TOOL_ROUTING_MODEL = "gpt-3.5-turbo-1106"

# make_recursive_tool_calls gives up once its recursion depth exceeds this value.
_MAX_RECURSIVE_COUNT = 5

# System prompt opening the tool-selection conversation.
_CHOOSE_BRAIN_FROM_TOOLS_PROMPT = (
    "Based on the provided user content, find the most appropriate tools to answer. "
    "If you can't find any tool to answer and only then, and if you can answer without using any tool. "
    "In that case, let the user know that you are not using any particular brain (i.e tool) "
)

# System prompt appended after the connected brains' answers (non-streaming path).
_ANSWER_FROM_TOOLS_PROMPT = (
    "If initial question can be answered by our conversation messages, "
    "then give an answer and end the conversation."
)

# System prompt appended after the connected brains' answers (streaming path).
_STREAM_ANSWER_FROM_TOOLS_PROMPT = (
    "If the last user's question can be answered by our conversation messages since then, "
    "then give an answer and end the conversation. "
    "If you need to ask question to the user to gather more information and give a more accurate answer, "
    "then ask the question and wait for the user's answer."
)


class CompositeBrainQA(
    KnowledgeBrainQA,
):
    """Answer questions asked to a composite brain by delegating to its connected brains.

    Every connected brain is offered to the LLM as a function-calling tool named
    after the brain id (see format_brain_to_tool). The model picks the brains to
    query. Their answers are appended as ``tool`` messages, and the model then
    writes the final reply. A composite brain with no usable connected brain is
    answered by HeadlessQA instead.
    """

    user_id: UUID

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        **kwargs,
    ):
        """Create the QA object. ``user_id`` must be given as a keyword argument.

        All arguments, including ``user_id``, are forwarded to KnowledgeBrainQA.

        Raises:
            HTTPException: 400 when no ``user_id`` keyword argument is given.
        """
        user_id = kwargs.get("user_id")
        if not user_id:
            raise HTTPException(status_code=400, detail="Cannot find user id")

        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            prompt_id=prompt_id,
            **kwargs,
        )
        self.user_id = user_id

    def get_answer_generator_from_brain_type(self, brain: BrainEntity):
        """Return a ``generate_answer`` callable that answers on behalf of ``brain``.

        The QA object is built with this instance's chat, model, token, temperature,
        streaming and prompt settings. Returns None for brain types not handled
        here.
        """
        brain_kwargs = {
            "brain_id": str(brain.id),
            "chat_id": self.chat_id,
            "model": self.model,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "streaming": self.streaming,
            "prompt_id": self.prompt_id,
        }
        if brain.brain_type == BrainType.COMPOSITE:
            # A connected composite brain must be answered from ITS OWN connected
            # brains. self.generate_answer always looks up self.brain_id's
            # connections, so build a dedicated instance for that brain.
            return CompositeBrainQA(
                **brain_kwargs, user_id=str(self.user_id)
            ).generate_answer
        if brain.brain_type == BrainType.API:
            return APIBrainQA(
                **brain_kwargs, user_id=str(self.user_id)
            ).generate_answer
        if brain.brain_type == BrainType.DOC:
            return KnowledgeBrainQA(**brain_kwargs).generate_answer
        return None

    def _build_brain_tools(self, connected_brain_ids):
        """Turn the connected brain ids into ``(tools, available_functions)``.

        ``available_functions`` maps each tool name (``str(brain.id)``, as produced
        by format_brain_to_tool) to the brain's answer generator. Brains that no
        longer exist, or whose type has no answer generator, are skipped.
        """
        tools = []
        available_functions = {}
        for connected_brain_id in connected_brain_ids:
            connected_brain = brain_service.get_brain_by_id(connected_brain_id)
            if connected_brain is None:
                continue
            answer_generator = self.get_answer_generator_from_brain_type(
                connected_brain
            )
            if answer_generator is None:
                # Offering a tool we cannot call would fail when the model picks it.
                continue
            tools.append(format_brain_to_tool(connected_brain))
            # Key by the tool name so a name returned by the model is found
            # whatever type the connected ids were returned as.
            available_functions[str(connected_brain.id)] = answer_generator
        return tools, available_functions

    def _build_tool_selection_messages(self, question, skip_message_id=None):
        """Build the routing conversation: system prompt, chat history, then the question.

        ``skip_message_id`` drops one history entry. generate_stream uses it for
        the placeholder row it saves before streaming, so the current question
        is not sent twice.
        """
        messages = [{"role": "system", "content": _CHOOSE_BRAIN_FROM_TOOLS_PROMPT}]
        for message in chat_service.get_chat_history(self.chat_id):
            if skip_message_id is not None and str(
                getattr(message, "message_id", None)
            ) == str(skip_message_id):
                continue
            if message.assistant is None:
                logger.debug("Chat history message without an answer: %s", message)
            messages.append({"role": "user", "content": message.user_message})
            messages.append({"role": "assistant", "content": message.assistant})
        messages.append({"role": "user", "content": question.question})
        return messages

    @staticmethod
    def _complete_with_tools(messages, tools, stream=False):
        """Ask the routing model to answer ``messages``, letting it call ``tools``."""
        stream_kwargs = {"stream": True} if stream else {}
        return completion(
            model=_TOOL_ROUTING_MODEL,
            messages=messages,
            tools=tools,
            tool_choice="auto",
            **stream_kwargs,
        )

    @staticmethod
    def _query_connected_brain(function_to_call, brain_id, raw_arguments, chat_id):
        """Ask one connected brain the sub-question the model wrote for it.

        ``raw_arguments`` is the tool call's JSON argument string. Its "question"
        key is required by the schema in format_brain_to_tool. Returns the
        brain's answer text (its ``assistant`` field).
        """
        function_args = json.loads(raw_arguments)
        # A separate name keeps the caller's original ``question`` intact.
        brain_question = ChatQuestion(
            question=function_args["question"], brain_id=brain_id
        )
        logger.info(
            "Querying brain %s with the following arguments: %s",
            brain_id,
            function_args,
        )
        # TODO: extract chat_id from generate_answer function of XBrainQA
        brain_response = function_to_call(
            chat_id=chat_id,
            question=brain_question,
            save_answer=False,
        )
        logger.debug("Answer from brain %s: %s", brain_id, brain_response.assistant)
        return brain_response.assistant

    @staticmethod
    def _message_field(message, key):
        """Read ``key`` from a chat message given as a dict or as a message object."""
        if isinstance(message, dict):
            return message.get(key)
        return getattr(message, key, None)

    def _finalize_answer(self, chat_id, question, answer, brain, save_answer):
        """Optionally save ``answer`` to the chat history and wrap it for the caller."""
        new_chat = None
        if save_answer:
            new_chat = chat_service.update_chat_history(
                CreateChatHistory(
                    **{
                        "chat_id": chat_id,
                        "user_message": question.question,
                        "assistant": answer,
                        "brain_id": question.brain_id,
                        "prompt_id": self.prompt_to_use_id,
                    }
                )
            )
        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                "message_time": new_chat.message_time if new_chat else None,
                "prompt_title": self.prompt_to_use.title
                if self.prompt_to_use
                else None,
                "brain_name": brain.name if brain else None,
                "message_id": new_chat.message_id if new_chat else None,
            }
        )

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool
    ) -> Optional[GetChatHistoryOutput]:
        """Answer ``question`` by routing it through the connected brains.

        Args:
            chat_id: chat the answer belongs to.
            question: the user's question; ``question.brain_id`` names the brain
                shown in the output.
            save_answer: when True, the question/answer pair is stored in the chat
                history and the output carries its message id and time.

        Returns:
            The answer wrapped in a GetChatHistoryOutput. Returns None when the
            model never produced a final answer (too many tool rounds, no usable
            tool call, or an unexpected finish reason). In that case nothing is
            saved.
        """
        brain = brain_service.get_brain_by_id(question.brain_id)
        connected_brains = brain_service.get_connected_brains(self.brain_id)
        tools, available_functions = self._build_brain_tools(connected_brains or [])

        if not tools:
            # No usable connected brain: answer directly without tools.
            headless_response = HeadlessQA(
                chat_id=chat_id,
                model=self.model,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                streaming=self.streaming,
                prompt_id=self.prompt_id,
            ).generate_answer(chat_id, question, save_answer=False)
            return self._finalize_answer(
                chat_id, question, headless_response.assistant, brain, save_answer
            )

        messages = self._build_tool_selection_messages(question)
        response = self._complete_with_tools(messages, tools)
        brain_completion_output = self.make_recursive_tool_calls(
            messages,
            question,
            chat_id,
            tools,
            available_functions,
            recursive_count=0,
            last_completion_response=response.choices[0],
        )
        if not brain_completion_output:
            return None
        return self._finalize_answer(
            chat_id,
            question,
            brain_completion_output.response.message.content,
            brain,
            save_answer,
        )

    def make_recursive_tool_calls(
        self,
        messages,
        question,
        chat_id,
        tools=None,
        available_functions=None,
        recursive_count=0,
        last_completion_response: CompletionResponse = None,
    ) -> Optional[BrainCompletionOutput]:
        """Run tool-call rounds until the model stops, and return the final completion.

        ``messages`` is extended in place with the model's messages, the tool
        answers and follow-up system prompts. ``question`` is the user's original
        question and is reported unchanged in the result.

        Returns:
            A BrainCompletionOutput when the model finishes with "stop". Returns
            None if recursion exceeds the limit, the model asks for no (or
            unknown) tools, or it finishes for any other reason.
        """
        tools = [] if tools is None else tools
        available_functions = {} if available_functions is None else available_functions

        if recursive_count > _MAX_RECURSIVE_COUNT:
            logger.warning(
                "The assistant is having issues and took more than %s calls to the tools. Please try again later or an other instruction.",
                _MAX_RECURSIVE_COUNT,
            )
            return None

        finish_reason = last_completion_response.finish_reason
        if finish_reason == "stop":
            messages.append(last_completion_response.message)
            return BrainCompletionOutput(
                **{
                    "messages": messages,
                    "question": question.question,
                    "response": last_completion_response,
                }
            )

        if finish_reason != "tool_calls":
            logger.warning("Unexpected finish reason from the model: %s", finish_reason)
            return None

        response_message: CompletionMessage = last_completion_response.message
        tool_calls = response_message.tool_calls
        messages.append(response_message)
        # ``not tool_calls`` covers both None and an empty list.
        if not tool_calls or not available_functions:
            return None

        for tool_call in tool_calls:
            function_name = tool_call.function.name
            answer = self._query_connected_brain(
                available_functions[function_name],
                function_name,
                tool_call.function.arguments,
                chat_id,
            )
            messages.append(
                {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": answer,
                }
            )

        messages.append({"role": "system", "content": _ANSWER_FROM_TOOLS_PROMPT})
        for idx, msg in enumerate(messages):
            # Entries are dicts or litellm message objects; read both safely.
            logger.debug(
                "Message %s: Role - %s, Content - %s",
                idx,
                self._message_field(msg, "role"),
                self._message_field(msg, "content"),
            )

        response_after_tools_answers = self._complete_with_tools(messages, tools)
        return self.make_recursive_tool_calls(
            messages,
            question,
            chat_id,
            tools,
            available_functions,
            recursive_count=recursive_count + 1,
            last_completion_response=response_after_tools_answers.choices[0],
        )

    @staticmethod
    def _aggregate_tool_call_deltas(tool_calls_aggregate, delta_tool_calls):
        """Merge streamed tool-call fragments into complete tool calls, in place.

        A fragment carrying both an id and a function name starts a new tool
        call. Any other fragment appends its argument text to the tool call at
        its ``index``.
        """
        for tool_call in delta_tool_calls:
            if tool_call.id and tool_call.function.name:
                tool_calls_aggregate.append(
                    {
                        "id": tool_call.id,
                        "function": {
                            # A None fragment would break the later ``+=``.
                            "arguments": tool_call.function.arguments or "",
                            "name": tool_call.function.name,
                        },
                        "type": "function",
                    }
                )
                continue
            try:
                tool_calls_aggregate[tool_call.index]["function"][
                    "arguments"
                ] += (tool_call.function.arguments or "")
            except IndexError:
                logger.error(
                    "TOOL_CALL_INDEX error %s, TOOL_CALLS_AGGREGATE error %s",
                    tool_call.index,
                    tool_calls_aggregate,
                )

    @staticmethod
    def _save_streamed_answer(message_id, user_message, response_tokens):
        """Store the streamed answer in the placeholder chat-history row."""
        chat_service.update_message_by_id(
            message_id=str(message_id),
            user_message=user_message,
            assistant="".join(
                token for token in response_tokens if not token.startswith("🧠<")
            ),
        )

    async def generate_stream(
        self,
        chat_id: UUID,
        question: ChatQuestion,
        save_answer: bool,
        should_log_steps: Optional[bool] = True,
    ):
        """Stream the answer to ``question`` as it is generated.

        Yields ``"data: <json>"`` strings built from a GetChatHistoryOutput whose
        ``assistant`` field holds the latest token. In the first model pass
        without ``save_answer``, tokens are yielded as ``"🧠<' <token>"`` instead.
        When ``save_answer`` is True, a placeholder history row is created up
        front and filled with the full answer at the end. ``should_log_steps``
        is currently unused.
        """
        brain = brain_service.get_brain_by_id(question.brain_id)

        message_id = None
        message_time = None
        if save_answer:
            placeholder = chat_service.update_chat_history(
                CreateChatHistory(
                    **{
                        "chat_id": chat_id,
                        "user_message": question.question,
                        "assistant": "",
                        "brain_id": question.brain_id,
                        "prompt_id": self.prompt_to_use_id,
                    }
                )
            )
            message_id = placeholder.message_id
            message_time = placeholder.message_time

        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": message_id,
                "message_time": message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": self.prompt_to_use.title
                if self.prompt_to_use
                else None,
                "brain_name": brain.name if brain else None,
            }
        )

        connected_brains = brain_service.get_connected_brains(self.brain_id)
        tools, available_functions = self._build_brain_tools(connected_brains or [])

        if not tools:
            headless_answer = HeadlessQA(
                chat_id=chat_id,
                model=self.model,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                streaming=self.streaming,
                prompt_id=self.prompt_id,
            ).generate_stream(chat_id, question)
            response_tokens = []
            async for value in headless_answer:
                streamed_chat_history.assistant = value
                response_tokens.append(value)
                yield f"data: {json.dumps(streamed_chat_history.dict())}"
            if save_answer:
                chat_service.update_message_by_id(
                    message_id=str(streamed_chat_history.message_id),
                    user_message=question.question,
                    assistant="".join(response_tokens),
                )
            # Without tools there is nothing left to route.
            return

        messages = self._build_tool_selection_messages(
            question, skip_message_id=streamed_chat_history.message_id
        )
        initial_response = self._complete_with_tools(messages, tools, stream=True)

        response_tokens = []
        tool_calls_aggregate = []
        finish_reason = None
        for chunk in initial_response:
            delta = chunk.choices[0].delta
            content = delta.content
            if content is not None:
                streamed_chat_history.assistant = content
                response_tokens.append(content)
                if save_answer:
                    yield f"data: {json.dumps(streamed_chat_history.dict())}"
                else:
                    yield f"🧠<' {content}"

            delta_tool_calls = getattr(delta, "tool_calls", None)
            if delta_tool_calls is not None:
                self._aggregate_tool_call_deltas(tool_calls_aggregate, delta_tool_calls)

            finish_reason = chunk.choices[0].finish_reason
            if finish_reason in ("stop", "tool_calls"):
                break

        if finish_reason == "stop":
            if save_answer:
                self._save_streamed_answer(
                    streamed_chat_history.message_id, question.question, response_tokens
                )
            return
        if finish_reason != "tool_calls":
            logger.warning("Stream ended with unexpected finish reason: %s", finish_reason)
            return

        messages.append(
            {
                "role": "assistant",
                "tool_calls": tool_calls_aggregate,
                "content": None,
            }
        )
        for tool_call in tool_calls_aggregate:
            function_name = tool_call["function"]["name"]
            answer = self._query_connected_brain(
                available_functions[function_name],
                function_name,
                tool_call["function"]["arguments"],
                chat_id,
            )
            messages.append(
                {
                    "tool_call_id": tool_call["id"],
                    "role": "tool",
                    "name": function_name,
                    "content": answer,
                }
            )
        messages.append({"role": "system", "content": _STREAM_ANSWER_FROM_TOOLS_PROMPT})
        logger.debug("messages %s", messages)

        response_after_tools_answers = self._complete_with_tools(
            messages, tools, stream=True
        )
        response_tokens = []
        for chunk in response_after_tools_answers:
            content = chunk.choices[0].delta.content
            if content:
                streamed_chat_history.assistant = content
                response_tokens.append(content)
                yield f"data: {json.dumps(streamed_chat_history.dict())}"

            finish_reason = chunk.choices[0].finish_reason
            if finish_reason == "stop":
                # Only a saved placeholder row has an id to update.
                if save_answer:
                    self._save_streamed_answer(
                        streamed_chat_history.message_id,
                        question.question,
                        response_tokens,
                    )
                break
            if finish_reason is not None:
                # TODO: recursively call with tools (update prompt + create intermediary function)
                logger.warning("NO STOP: %s", chunk.choices[0])