import logging
from typing import Any, List, Optional, Tuple, no_type_check

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document

from quivr_core.models import (
    ChatLLMMetadata,
    ParsedRAGResponse,
    QuivrKnowledge,
    RAGResponseMetadata,
    RawRAGResponse,
)
from quivr_core.prompts import DEFAULT_DOCUMENT_PROMPT

# TODO(@aminediro): define a types package where we clearly define IO types
# This should be used for serialization/deserialization later


logger = logging.getLogger("quivr_core")

# Model names for which model_supports_function_calling() returns True.
# Built once at import time instead of on every call.
_MODELS_SUPPORTING_FUNCTION_CALLS = frozenset(
    {
        "gpt-4",
        "gpt-4-1106-preview",
        "gpt-4-0613",
        "gpt-3.5-turbo-0125",
        "gpt-3.5-turbo-1106",
        "gpt-3.5-turbo-0613",
        "gpt-4-0125-preview",
        "gpt-3.5-turbo",
        "gpt-4-turbo",
        "gpt-4o",
        "gpt-4o-mini",
    }
)


def model_supports_function_calling(model_name: str) -> bool:
    """Return True if *model_name* is in the known function-calling model set.

    The match is an exact string comparison against a fixed allow-list.
    """
    return model_name in _MODELS_SUPPORTING_FUNCTION_CALLS


def format_history_to_openai_mesages(
    tuple_history: List[Tuple[str, str]], system_message: str, question: str
) -> List[BaseMessage]:
    """Format the chat history into a list of Base Messages.

    The result starts with the system message, followed by one
    ``HumanMessage``/``AIMessage`` pair per ``(human, ai)`` tuple in
    *tuple_history* (in order), and ends with *question* as a ``HumanMessage``.
    """
    messages: List[BaseMessage] = [SystemMessage(content=system_message)]
    for human, ai in tuple_history:
        messages.append(HumanMessage(content=human))
        messages.append(AIMessage(content=ai))
    messages.append(HumanMessage(content=question))
    return messages


def cited_answer_filter(tool):
    """Return True if the tool-call mapping *tool* is named ``cited_answer``."""
    return tool["name"] == "cited_answer"


def _get_cited_answer_args(msg: AIMessageChunk) -> dict[str, Any]:
    """Return the ``args`` of the first ``cited_answer`` tool call on *msg*.

    An empty dict is returned when the message has no tool calls, when none
    of them is named ``cited_answer``, or when the matching call has no
    (or empty) ``args``. This avoids the ``StopIteration`` a bare ``next()``
    raises when no tool call matches.
    """
    cited_answer = next(filter(cited_answer_filter, msg.tool_calls or []), None)
    if cited_answer is None:
        return {}
    return cited_answer.get("args") or {}


def get_chunk_metadata(
    msg: AIMessageChunk, sources: Optional[list[Any]] = None
) -> RAGResponseMetadata:
    """Build the response metadata for a (possibly partial) streamed message.

    Args:
        msg: Message whose ``cited_answer`` tool-call arguments are inspected.
        sources: Source entries to attach; defaults to a fresh empty list.

    Returns:
        A ``RAGResponseMetadata`` carrying ``sources`` plus ``citations`` and
        ``followup_questions`` when those keys are present in the tool-call
        arguments.
    """
    # A None default avoids sharing one mutable list across calls.
    metadata: dict[str, Any] = {"sources": [] if sources is None else sources}

    gathered_args = _get_cited_answer_args(msg)
    for key in ("citations", "followup_questions"):
        if key in gathered_args:
            metadata[key] = gathered_args[key]

    return RAGResponseMetadata(**metadata)


def get_prev_message_str(msg: AIMessageChunk) -> str:
    """Return the ``answer`` argument of the ``cited_answer`` tool call on *msg*.

    Returns an empty string when there is no such tool call or it has no
    ``answer`` argument.
    """
    return _get_cited_answer_args(msg).get("answer", "")


# TODO(@aminediro): redo this
@no_type_check
def parse_chunk_response(
    rolling_msg: AIMessageChunk,
    raw_chunk: dict[str, Any],
    supports_func_calling: bool,
) -> Tuple[AIMessageChunk, str]:
    """Fold a streamed chunk into the rolling message and extract answer text.

    Args:
        rolling_msg: Accumulated message so far.
        raw_chunk: Mapping whose ``"answer"`` entry is the new chunk.
        supports_func_calling: Whether the answer is delivered through the
            ``cited_answer`` tool call rather than the chunk content.

    Returns:
        ``(rolling_msg, answer_str)`` where ``rolling_msg`` includes the new
        chunk. Without function calling, ``answer_str`` is the content of the
        new chunk only. With function calling, it is the ``answer`` argument
        gathered so far on the rolling message (``""`` if not available yet).
    """
    rolling_msg += raw_chunk["answer"]

    if not supports_func_calling:
        return rolling_msg, raw_chunk["answer"].content

    # NOTE(review): this is the whole answer accumulated so far, not a delta;
    # diffing against the previously sent answer is presumably done by the
    # caller -- confirm.
    answer_str = _get_cited_answer_args(rolling_msg).get("answer", "")
    return rolling_msg, answer_str


@no_type_check
def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
    """Convert a raw RAG response into a ``ParsedRAGResponse``.

    The answer text is the content of ``raw_response["answer"]`` and the
    sources are ``raw_response["docs"]`` (or an empty list when falsy). When
    *model_name* supports function calling and the answer has tool calls,
    ``citations`` and non-empty ``followup_questions`` are read from the
    arguments of the last tool call, if present.
    """
    answer_msg = raw_response["answer"]
    answer = answer_msg.content
    sources = raw_response["docs"] or []

    metadata = RAGResponseMetadata(
        sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
    )

    if model_supports_function_calling(model_name) and answer_msg.tool_calls:
        # Tolerate tool calls whose args lack some keys instead of raising
        # KeyError.
        args = answer_msg.tool_calls[-1].get("args") or {}
        if "citations" in args:
            metadata.citations = args["citations"]
        followup_questions = args.get("followup_questions")
        if followup_questions:
            metadata.followup_questions = followup_questions

    return ParsedRAGResponse(answer=answer, metadata=metadata)


def combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
    """Format *docs* with *document_prompt* and join them into one string.

    Side effect: each document's ``metadata["index"]`` is set to its position
    in *docs*, so that sources can be cited by index.
    """
    # for each doc, add an index in the metadata to be able to cite the sources
    for index, doc in enumerate(docs):
        doc.metadata["index"] = index
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)


def format_file_list(
    list_files_array: list[QuivrKnowledge], max_files: int = 20
) -> str:
    """Return a newline-separated listing of file names.

    Each entry is the file's ``file_name``, falling back to its ``url`` when
    the name is falsy; entries that are still ``None`` are dropped. At most
    *max_files* entries are kept. The literal string ``"None"`` is returned
    when *list_files_array* is empty.
    """
    candidates = [file.file_name or file.url for file in list_files_array]
    files: list[str] = [name for name in candidates if name is not None]
    files = files[:max_files]

    return "\n".join(files) if list_files_array else "None"