Mirror of https://github.com/QuivrHQ/quivr.git (synced 2024-12-14 07:59:00 +03:00)
285fe5b960
# Description

This PR includes far too many new features:

- detection of user intent (closes CORE-211)
- treating multiple questions in parallel (closes CORE-212)
- using the chat history when answering a question (closes CORE-213)
- filtering of retrieved chunks by relevance threshold (closes CORE-217)
- dynamic retrieval of chunks (closes CORE-218)
- enabling web search via Tavily (closes CORE-220)
- enabling agent / assistant to activate tools when relevant to complete the user task (closes CORE-224)

Also closes CORE-205

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---------

Co-authored-by: Stan Girard <stan@quivr.app>
189 lines
6.5 KiB
Python
import logging
from typing import Any, List, Tuple, no_type_check

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document

from quivr_core.rag.entities.config import WorkflowConfig
from quivr_core.rag.entities.models import (
    ChatLLMMetadata,
    ParsedRAGResponse,
    QuivrKnowledge,
    RAGResponseMetadata,
    RawRAGResponse,
)
from quivr_core.rag.prompts import custom_prompts

# TODO(@aminediro): define a types package where we clearly define IO types
# This should be used for serialization/deserialization later


logger = logging.getLogger("quivr_core")


def model_supports_function_calling(model_name: str):
    models_not_supporting_function_calls: list[str] = ["llama2", "test", "ollama3"]

    return model_name not in models_not_supporting_function_calls


def format_history_to_openai_mesages(
    tuple_history: List[Tuple[str, str]], system_message: str, question: str
) -> List[BaseMessage]:
    """Format the chat history into a list of Base Messages"""
    messages = []
    messages.append(SystemMessage(content=system_message))
    for human, ai in tuple_history:
        messages.append(HumanMessage(content=human))
        messages.append(AIMessage(content=ai))
    messages.append(HumanMessage(content=question))
    return messages


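# Illustrative sketch (not part of the original module): how the history
# formatter above might be called. The system prompt, history and question
# values are made up for the example.
def _example_format_history() -> List[BaseMessage]:
    tuple_history = [
        ("What is Quivr?", "Quivr is an open-source RAG framework."),
        ("Does it support streaming?", "Yes, responses can be streamed."),
    ]
    return format_history_to_openai_mesages(
        tuple_history,
        system_message="You are a helpful assistant.",
        question="Which models does it work with?",
    )

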
def cited_answer_filter(tool):
    return tool["name"] == "cited_answer"


def get_chunk_metadata(
    msg: AIMessageChunk, sources: list[Any] | None = None
) -> RAGResponseMetadata:
    metadata = {"sources": sources or []}

    if not msg.tool_calls:
        return RAGResponseMetadata(**metadata, metadata_model=None)

    all_citations = []
    all_followup_questions = []

    for tool_call in msg.tool_calls:
        if tool_call.get("name") == "cited_answer" and "args" in tool_call:
            args = tool_call["args"]
            all_citations.extend(args.get("citations", []))
            all_followup_questions.extend(args.get("followup_questions", []))

    metadata["citations"] = all_citations
    metadata["followup_questions"] = all_followup_questions[:3]  # Limit to 3

    return RAGResponseMetadata(**metadata, metadata_model=None)


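# Illustrative sketch (not part of the original module): extracting citation
# metadata once streaming has finished. `final_msg` stands for the fully
# accumulated AIMessageChunk and `docs` for the retrieved sources; both are
# assumed to come from the caller.
def _example_chunk_metadata(
    final_msg: AIMessageChunk, docs: list[Any]
) -> RAGResponseMetadata:
    # Citations and follow-up questions are read from the "cited_answer" tool
    # call if the model emitted one; otherwise only the sources are kept.
    return get_chunk_metadata(final_msg, sources=docs)

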
def get_prev_message_str(msg: AIMessageChunk) -> str:
    if msg.tool_calls:
        # Guard with a default so a message whose tool calls contain no
        # "cited_answer" entry does not raise StopIteration.
        cited_answer = next(
            (x for x in msg.tool_calls if cited_answer_filter(x)), None
        )
        if cited_answer and "args" in cited_answer and "answer" in cited_answer["args"]:
            return cited_answer["args"]["answer"]
    return ""


# TODO: CONVOLUTED LOGIC !
# TODO(@aminediro): redo this
@no_type_check
def parse_chunk_response(
    rolling_msg: AIMessageChunk,
    raw_chunk: AIMessageChunk,
    supports_func_calling: bool,
    previous_content: str = "",
) -> Tuple[AIMessageChunk, str, str]:
    """Parse a chunk response
    Args:
        rolling_msg: The accumulated message so far
        raw_chunk: The new chunk to add
        supports_func_calling: Whether function calling is supported
        previous_content: The previous content string
    Returns:
        Tuple of (updated rolling message, new content only, full content)
    """
    rolling_msg += raw_chunk

    if not supports_func_calling or not rolling_msg.tool_calls:
        new_content = raw_chunk.content  # Just the new chunk's content
        full_content = rolling_msg.content  # The full accumulated content
        return rolling_msg, new_content, full_content

    current_answers = get_answers_from_tool_calls(rolling_msg.tool_calls)
    full_answer = "\n\n".join(current_answers)
    new_content = full_answer[len(previous_content) :]

    return rolling_msg, new_content, full_answer


def get_answers_from_tool_calls(tool_calls):
    answers = []
    for tool_call in tool_calls:
        if tool_call.get("name") == "cited_answer" and "args" in tool_call:
            answers.append(tool_call["args"].get("answer", ""))
    return answers


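# Illustrative sketch (not part of the original module): how a streaming caller
# might thread `rolling_msg` and `previous_content` through parse_chunk_response.
# `llm_stream` is a hypothetical iterable of AIMessageChunk objects (e.g. the
# chunks yielded by a LangChain chat model's .stream() call).
def _example_stream_answer(llm_stream, supports_func_calling: bool) -> str:
    rolling_msg = AIMessageChunk(content="")
    previous_content = ""
    for raw_chunk in llm_stream:
        # full content so far is fed back in as previous_content, so each call
        # only returns the newly generated delta
        rolling_msg, new_content, previous_content = parse_chunk_response(
            rolling_msg, raw_chunk, supports_func_calling, previous_content
        )
        if new_content:
            logger.debug("streamed delta: %s", new_content)
    return previous_content

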
@no_type_check
def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
    answers = []
    sources = raw_response["docs"] if "docs" in raw_response else []

    metadata = RAGResponseMetadata(
        sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
    )

    if (
        model_supports_function_calling(model_name)
        and "tool_calls" in raw_response["answer"]
        and raw_response["answer"].tool_calls
    ):
        all_citations = []
        all_followup_questions = []
        for tool_call in raw_response["answer"].tool_calls:
            if "args" in tool_call:
                args = tool_call["args"]
                if "citations" in args:
                    all_citations.extend(args["citations"])
                if "followup_questions" in args:
                    all_followup_questions.extend(args["followup_questions"])
                if "answer" in args:
                    answers.append(args["answer"])
        metadata.citations = all_citations
        metadata.followup_questions = all_followup_questions
    else:
        answers.append(raw_response["answer"].content)

    answer_str = "\n".join(answers)
    parsed_response = ParsedRAGResponse(answer=answer_str, metadata=metadata)
    return parsed_response


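# Illustrative sketch (not part of the original module): parsing a non-streaming
# result. The shape of `raw_response` (an "answer" message plus retrieved "docs")
# is inferred from the accesses made in parse_response above; the answer text and
# model name are arbitrary example values.
def _example_parse_response() -> ParsedRAGResponse:
    raw_response = {
        "answer": AIMessage(content="Quivr builds a RAG pipeline over your files."),
        "docs": [],
    }
    return parse_response(raw_response, model_name="gpt-4o")

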
def combine_documents(
    docs,
    document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT,
    document_separator="\n\n",
):
    # For each doc, add an index to the metadata to be able to cite the sources
    for index, doc in enumerate(docs):
        doc.metadata["index"] = index
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)


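# Illustrative sketch (not part of the original module): joining retrieved
# documents into a single context string. The page contents and the simple
# citation prompt are made up; the real default prompt is
# custom_prompts.DEFAULT_DOCUMENT_PROMPT.
def _example_combine_documents() -> str:
    # Imports needed only for this sketch.
    from langchain_core.documents import Document
    from langchain_core.prompts import PromptTemplate

    docs = [
        Document(page_content="Quivr supports multiple vector stores."),
        Document(page_content="Chunks can be filtered by a relevance threshold."),
    ]
    # Each document gets an "index" metadata entry so the LLM can cite it.
    prompt = PromptTemplate.from_template("[{index}] {page_content}")
    return combine_documents(docs, document_prompt=prompt)

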
def format_file_list(
    list_files_array: list[QuivrKnowledge], max_files: int = 20
) -> str:
    list_files = [file.file_name or file.url for file in list_files_array]
    files: list[str] = list(filter(lambda n: n is not None, list_files))  # type: ignore
    files = files[:max_files]

    files_str = "\n".join(files) if list_files_array else "None"
    return files_str


def collect_tools(workflow_config: WorkflowConfig):
    validated_tools = "Available tools which can be activated:\n"
    for i, tool in enumerate(workflow_config.validated_tools):
        validated_tools += f"Tool {i+1} name: {tool.name}\n"
        validated_tools += f"Tool {i+1} description: {tool.description}\n\n"

    activated_tools = "Activated tools which can be deactivated:\n"
    for i, tool in enumerate(workflow_config.activated_tools):
        activated_tools += f"Tool {i+1} name: {tool.name}\n"
        activated_tools += f"Tool {i+1} description: {tool.description}\n\n"

    return validated_tools, activated_tools
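

# Illustrative sketch (not part of the original module): feeding the tool
# listings into a prompt. The WorkflowConfig instance is assumed to come from
# the caller, and the prompt wording here is made up.
def _example_describe_tools(workflow_config: WorkflowConfig) -> str:
    validated_tools, activated_tools = collect_tools(workflow_config)
    return (
        "Decide which tools to toggle for this task.\n\n"
        f"{validated_tools}\n{activated_tools}"
    )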