import logging
from typing import Any, List, Tuple, no_type_check

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document

from quivr_core.models import (
    ChatLLMMetadata,
    ParsedRAGResponse,
    QuivrKnowledge,
    RAGResponseMetadata,
    RawRAGResponse,
)
from quivr_core.prompts import custom_prompts

# TODO(@aminediro): define a types package where we clearly define IO types
# This should be used for serialization/deserialization later


logger = logging.getLogger("quivr_core")


def model_supports_function_calling(model_name: str):
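    """Return True if the given model name is known to support function calling."""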
    models_supporting_function_calls = [
        "gpt-4",
        "gpt-4-1106-preview",
        "gpt-4-0613",
        "gpt-4o",
        "gpt-3.5-turbo-1106",
        "gpt-3.5-turbo-0613",
        "gpt-4-0125-preview",
        "gpt-3.5-turbo",
        "gpt-4-turbo",
        "gpt-4o-mini",
    ]
    return model_name in models_supporting_function_calls


def format_history_to_openai_mesages(
    tuple_history: List[Tuple[str, str]], system_message: str, question: str
) -> List[BaseMessage]:
    """Format the chat history into a list of Base Messages"""
    messages = []
    messages.append(SystemMessage(content=system_message))
    for human, ai in tuple_history:
        messages.append(HumanMessage(content=human))
        messages.append(AIMessage(content=ai))
    messages.append(HumanMessage(content=question))
    return messages
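# Illustrative call (shape only):
#   format_history_to_openai_mesages([("Hi", "Hello!")], "You are helpful.", "Who are you?")
#   -> [SystemMessage, HumanMessage("Hi"), AIMessage("Hello!"), HumanMessage("Who are you?")]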


def cited_answer_filter(tool):
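    """Return True for tool calls produced by the 'cited_answer' tool."""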
    return tool["name"] == "cited_answer"


def get_chunk_metadata(
    msg: AIMessageChunk, sources: list[Any] | None = None
) -> RAGResponseMetadata:
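    """Build RAGResponseMetadata (sources, citations, follow-up questions) from a streamed message.

    Citations and follow-up questions are read from the accumulated
    'cited_answer' tool call on ``msg``, when present.
    """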
    # Initialize the sources
    metadata = {"sources": sources} if sources else {"sources": []}
    if msg.tool_calls:
        cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))

        if "args" in cited_answer:
            gathered_args = cited_answer["args"]
            if "citations" in gathered_args:
                citations = gathered_args["citations"]
                metadata["citations"] = citations

            if "followup_questions" in gathered_args:
                followup_questions = gathered_args["followup_questions"]
                metadata["followup_questions"] = followup_questions

    return RAGResponseMetadata(**metadata, metadata_model=None)


def get_prev_message_str(msg: AIMessageChunk) -> str:
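    """Return the 'answer' argument gathered so far in the 'cited_answer' tool call, or ""."""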
    if msg.tool_calls:
        cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
        if "args" in cited_answer and "answer" in cited_answer["args"]:
            return cited_answer["args"]["answer"]
    return ""


# TODO: CONVOLUTED LOGIC !
# TODO(@aminediro): redo this
@no_type_check
def parse_chunk_response(
    rolling_msg: AIMessageChunk,
    raw_chunk: dict[str, Any],
    supports_func_calling: bool,
) -> Tuple[AIMessageChunk, str]:
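    """Fold ``raw_chunk`` into ``rolling_msg`` and return the updated message plus the answer text.

    With function calling, the answer is the 'answer' argument gathered so far on the
    'cited_answer' tool call; otherwise it is the chunk's plain content.
    """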
    # Initialize the answer string
    answer_str = ""

    if "answer" in raw_chunk:
        answer = raw_chunk["answer"]
    else:
        answer = raw_chunk

    rolling_msg += answer
    if supports_func_calling and rolling_msg.tool_calls:
        cited_answer = next(x for x in rolling_msg.tool_calls if cited_answer_filter(x))
        if "args" in cited_answer and "answer" in cited_answer["args"]:
            gathered_args = cited_answer["args"]
            # answer_str holds the full answer gathered so far; the previously
            # streamed portion can be recovered with get_prev_message_str
            answer_str = gathered_args["answer"]
        return rolling_msg, answer_str

    return rolling_msg, answer.content


@no_type_check
def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
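    """Parse a non-streaming RAG response into a ParsedRAGResponse.

    For function-calling models, the answer, citations, and follow-up questions are
    read from the arguments of the last tool call; otherwise the raw content is used.
    """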
    answer = ""
    sources = raw_response["docs"] if "docs" in raw_response else []

    metadata = RAGResponseMetadata(
        sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
    )

    if (
        model_supports_function_calling(model_name)
        and "tool_calls" in raw_response["answer"]
        and raw_response["answer"].tool_calls
    ):
        if "citations" in raw_response["answer"].tool_calls[-1]["args"]:
            citations = raw_response["answer"].tool_calls[-1]["args"]["citations"]
            metadata.citations = citations
            followup_questions = raw_response["answer"].tool_calls[-1]["args"][
                "followup_questions"
            ]
            if followup_questions:
                metadata.followup_questions = followup_questions
            answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]
        else:
            answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]
    else:
        answer = raw_response["answer"].content

    parsed_response = ParsedRAGResponse(answer=answer, metadata=metadata)
    return parsed_response


def combine_documents(
    docs,
    document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT,
    document_separator="\n\n",
):
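    """Tag each document with an 'index' in its metadata (for citations) and join their formatted contents."""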
    # for each doc, add an index in the metadata to be able to cite the sources
    for index, doc in enumerate(docs):
        doc.metadata["index"] = index
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
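# Illustrative usage, assuming LangChain Document inputs:
#   combine_documents([Document(page_content="a"), Document(page_content="b")])
#   sets metadata["index"] to 0 and 1, then joins the two rendered documents
#   with "\n\n" (DEFAULT_DOCUMENT_PROMPT controls how each one is rendered).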


def format_file_list(
    list_files_array: list[QuivrKnowledge], max_files: int = 20
) -> str:
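    """Return a newline-separated list of up to ``max_files`` file names (or URLs), or "None" if empty."""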
    list_files = [file.file_name or file.url for file in list_files_array]
    files: list[str] = list(filter(lambda n: n is not None, list_files))  # type: ignore
    files = files[:max_files]

    files_str = "\n".join(files) if list_files_array else "None"
    return files_str