quivr/core/quivr_core/rag/utils.py
Jacopo Chevallard 285fe5b960
feat: websearch, tool use, user intent, dynamic retrieval, multiple questions (#3424)
# Description

This PR includes far too many new features:

- detection of user intent (closes CORE-211)
- treating multiple questions in parallel (closes CORE-212)
- using the chat history when answering a question (closes CORE-213)
- filtering of retrieved chunks by relevance threshold (closes CORE-217)
- dynamic retrieval of chunks (closes CORE-218)
- enabling web search via Tavily (closes CORE-220)
- enabling agent / assistant to activate tools when relevant to complete
the user task (closes CORE-224)

Also closes CORE-205

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---------

Co-authored-by: Stan Girard <stan@quivr.app>
2024-10-31 17:57:54 +01:00

189 lines
6.5 KiB
Python

import logging
from typing import Any, List, Tuple, no_type_check
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document
from quivr_core.rag.entities.config import WorkflowConfig
from quivr_core.rag.entities.models import (
ChatLLMMetadata,
ParsedRAGResponse,
QuivrKnowledge,
RAGResponseMetadata,
RawRAGResponse,
)
from quivr_core.rag.prompts import custom_prompts
# TODO(@aminediro): define a types packages where we clearly define IO types
# This should be used for serialization/deseriallization later
logger = logging.getLogger("quivr_core")
def model_supports_function_calling(model_name: str):
models_not_supporting_function_calls: list[str] = ["llama2", "test", "ollama3"]
return model_name not in models_not_supporting_function_calls
def format_history_to_openai_mesages(
tuple_history: List[Tuple[str, str]], system_message: str, question: str
) -> List[BaseMessage]:
"""Format the chat history into a list of Base Messages"""
messages = []
messages.append(SystemMessage(content=system_message))
for human, ai in tuple_history:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
messages.append(HumanMessage(content=question))
return messages
def cited_answer_filter(tool):
return tool["name"] == "cited_answer"
def get_chunk_metadata(
msg: AIMessageChunk, sources: list[Any] | None = None
) -> RAGResponseMetadata:
metadata = {"sources": sources or []}
if not msg.tool_calls:
return RAGResponseMetadata(**metadata, metadata_model=None)
all_citations = []
all_followup_questions = []
for tool_call in msg.tool_calls:
if tool_call.get("name") == "cited_answer" and "args" in tool_call:
args = tool_call["args"]
all_citations.extend(args.get("citations", []))
all_followup_questions.extend(args.get("followup_questions", []))
metadata["citations"] = all_citations
metadata["followup_questions"] = all_followup_questions[:3] # Limit to 3
return RAGResponseMetadata(**metadata, metadata_model=None)
def get_prev_message_str(msg: AIMessageChunk) -> str:
if msg.tool_calls:
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
if "args" in cited_answer and "answer" in cited_answer["args"]:
return cited_answer["args"]["answer"]
return ""
# TODO: CONVOLUTED LOGIC !
# TODO(@aminediro): redo this
@no_type_check
def parse_chunk_response(
rolling_msg: AIMessageChunk,
raw_chunk: AIMessageChunk,
supports_func_calling: bool,
previous_content: str = "",
) -> Tuple[AIMessageChunk, str, str]:
"""Parse a chunk response
Args:
rolling_msg: The accumulated message so far
raw_chunk: The new chunk to add
supports_func_calling: Whether function calling is supported
previous_content: The previous content string
Returns:
Tuple of (updated rolling message, new content only, full content)
"""
rolling_msg += raw_chunk
if not supports_func_calling or not rolling_msg.tool_calls:
new_content = raw_chunk.content # Just the new chunk's content
full_content = rolling_msg.content # The full accumulated content
return rolling_msg, new_content, full_content
current_answers = get_answers_from_tool_calls(rolling_msg.tool_calls)
full_answer = "\n\n".join(current_answers)
new_content = full_answer[len(previous_content) :]
return rolling_msg, new_content, full_answer
def get_answers_from_tool_calls(tool_calls):
answers = []
for tool_call in tool_calls:
if tool_call.get("name") == "cited_answer" and "args" in tool_call:
answers.append(tool_call["args"].get("answer", ""))
return answers
@no_type_check
def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
answers = []
sources = raw_response["docs"] if "docs" in raw_response else []
metadata = RAGResponseMetadata(
sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
)
if (
model_supports_function_calling(model_name)
and "tool_calls" in raw_response["answer"]
and raw_response["answer"].tool_calls
):
all_citations = []
all_followup_questions = []
for tool_call in raw_response["answer"].tool_calls:
if "args" in tool_call:
args = tool_call["args"]
if "citations" in args:
all_citations.extend(args["citations"])
if "followup_questions" in args:
all_followup_questions.extend(args["followup_questions"])
if "answer" in args:
answers.append(args["answer"])
metadata.citations = all_citations
metadata.followup_questions = all_followup_questions
else:
answers.append(raw_response["answer"].content)
answer_str = "\n".join(answers)
parsed_response = ParsedRAGResponse(answer=answer_str, metadata=metadata)
return parsed_response
def combine_documents(
docs,
document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT,
document_separator="\n\n",
):
# for each docs, add an index in the metadata to be able to cite the sources
for doc, index in zip(docs, range(len(docs)), strict=False):
doc.metadata["index"] = index
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
def format_file_list(
list_files_array: list[QuivrKnowledge], max_files: int = 20
) -> str:
list_files = [file.file_name or file.url for file in list_files_array]
files: list[str] = list(filter(lambda n: n is not None, list_files)) # type: ignore
files = files[:max_files]
files_str = "\n".join(files) if list_files_array else "None"
return files_str
def collect_tools(workflow_config: WorkflowConfig):
validated_tools = "Available tools which can be activated:\n"
for i, tool in enumerate(workflow_config.validated_tools):
validated_tools += f"Tool {i+1} name: {tool.name}\n"
validated_tools += f"Tool {i+1} description: {tool.description}\n\n"
activated_tools = "Activated tools which can be deactivated:\n"
for i, tool in enumerate(workflow_config.activated_tools):
activated_tools += f"Tool {i+1} name: {tool.name}\n"
activated_tools += f"Tool {i+1} description: {tool.description}\n\n"
return validated_tools, activated_tools