fix: fixes duplicate response bug (#2843)

# Description

Closes #2794.

Fixes duplicate responses in the stream.
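In short: with function calling enabled, each streamed chunk carries the *cumulative* answer parsed so far, so yielding it unchanged re-sends text that was already streamed. The fix tracks what has been emitted and yields only the new suffix. A minimal standalone sketch of the idea (names mirror the diff; the chunk values are made up):

```python
def stream_deltas(cumulative_answers):
    """Yield only the not-yet-streamed suffix of each cumulative answer."""
    prev_answer = ""
    for answer_str in cumulative_answers:
        diff_answer = answer_str[len(prev_answer):]
        prev_answer += diff_answer
        if diff_answer:
            yield diff_answer

# Each fragment comes out exactly once: 'Paris', ' is', ' the capital'
print(list(stream_deltas(["Paris", "Paris is", "Paris is the capital"])))
```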
Commit 23ea00726a (parent 4b0d0f8144), authored by AmineDiro on 2024-07-11 15:09:43 +02:00, committed via GitHub.
2 changed files with 36 additions and 26 deletions

**File 1 of 2**
```diff
@@ -1,20 +1,24 @@
 import logging
 from operator import itemgetter
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional, Sequence
+
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain_cohere import CohereRerank
 from langchain_community.chat_models import ChatLiteLLM
+from langchain_core.callbacks import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from langchain_core.vectorstores import VectorStore
 from langchain_openai import ChatOpenAI
+
 from quivr_api.modules.knowledge.entity.knowledge import Knowledge
 from quivr_api.packages.quivr_core.config import RAGConfig
 from quivr_api.packages.quivr_core.models import (
     ParsedRAGChunkResponse,
     ParsedRAGResponse,
     RAGResponseMetadata,
     cited_answer,
 )
 from quivr_api.packages.quivr_core.prompts import (
```
```diff
@@ -33,6 +37,16 @@ from quivr_api.packages.quivr_core.utils import (

 logger = logging.getLogger(__name__)

+class IdempotentCompressor(BaseDocumentCompressor):
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        return documents
+
+
 class QuivrQARAG:
     def __init__(
         self,
```
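`IdempotentCompressor` is a no-op stand-in that satisfies the `BaseDocumentCompressor` interface while returning documents untouched. A hedged usage sketch (the retriever wiring is illustrative, not lifted from this commit; `vector_store` is an assumed `VectorStore` instance):

```python
from langchain.retrievers import ContextualCompressionRetriever

# With the no-op compressor plugged in, the compression retriever
# behaves exactly like the underlying retriever.
retriever = ContextualCompressionRetriever(
    base_compressor=IdempotentCompressor(),
    base_retriever=vector_store.as_retriever(),  # assumed, not from the PR
)
```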
```diff
@@ -55,13 +69,11 @@ class QuivrQARAG:

     def _create_reranker(self):
         # TODO: reranker config
-        compressor = CohereRerank(top_n=20)
-        # else:
-        #     ranker_model_name = "ms-marco-TinyBERT-L-2-v2"
-        #     flashrank_client = Ranker(model_name=ranker_model_name)
-        #     compressor = FlashrankRerank(
-        #         client=flashrank_client, model=ranker_model_name, top_n=20
-        #     )  # TODO @stangirard fix
+        try:
+            compressor = CohereRerank(top_n=20)
+        except Exception as e:
+            logger.exception(f"Can't load Cohere reranker: {e}")
+            compressor = IdempotentCompressor()
         return compressor

     # TODO : refactor and simplify
```
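The try/except keeps retrieval usable when `CohereRerank` cannot be constructed, for example when no Cohere API key is configured. A quick sanity sketch of the fallback's contract (hypothetical documents; assumes `IdempotentCompressor` is imported from the module above):

```python
from langchain_core.documents import Document

docs = [Document(page_content="chunk A"), Document(page_content="chunk B")]
# The fallback must hand documents back untouched so the rest of the
# retrieval pipeline sees the same sequence it would without reranking.
assert list(IdempotentCompressor().compress_documents(docs, query="q")) == docs
```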
```diff
@@ -194,6 +206,7 @@ class QuivrQARAG:

         rolling_message = AIMessageChunk(content="")
         sources = []
+        prev_answer = ""

         async for chunk in conversational_qa_chain.astream(
             {
```
```diff
@@ -208,18 +221,22 @@ class QuivrQARAG:
             sources = chunk["docs"] if "docs" in chunk else []

             if "answer" in chunk:
-                rolling_message, parsed_chunk = parse_chunk_response(
+                rolling_message, answer_str = parse_chunk_response(
                     rolling_message,
                     chunk,
                     self.supports_func_calling,
                 )

-                if self.supports_func_calling and len(parsed_chunk.answer) > 0:
-                    yield parsed_chunk
-                else:
+                if self.supports_func_calling and len(answer_str) > 0:
+                    diff_answer = answer_str[len(prev_answer) :]
+                    parsed_chunk = ParsedRAGChunkResponse(
+                        answer=diff_answer,
+                        metadata=RAGResponseMetadata(),
+                    )
+                    prev_answer += diff_answer
                     yield parsed_chunk

-        # Last chunk provies
+        # Last chunk provides metadata
         yield ParsedRAGChunkResponse(
             answer="",
             metadata=get_chunk_metadata(rolling_message, sources),
```
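Downstream, a consumer now sees each content fragment once, followed by a single terminating chunk whose `answer` is empty and whose metadata carries the gathered sources. A hedged consumption sketch (the `answer_astream` arguments and the render helpers are hypothetical):

```python
async def consume(rag, question, history, list_files):
    # Stream tokens as they arrive, then read the sources off the final
    # metadata-only chunk (answer == "").
    async for parsed in rag.answer_astream(question, history, list_files):
        if parsed.answer:
            render_token(parsed.answer)      # hypothetical UI hook
        else:
            render_sources(parsed.metadata)  # terminal chunk: metadata only
```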

**File 2 of 2**
```diff
@@ -18,7 +18,6 @@ from quivr_api.modules.upload.service.generate_file_signed_url import (
     generate_file_signed_url,
 )
 from quivr_api.packages.quivr_core.models import (
-    ParsedRAGChunkResponse,
     ParsedRAGResponse,
     RAGResponseMetadata,
     RawRAGResponse,
```
```diff
@@ -83,6 +82,7 @@ def get_prev_message_str(msg: AIMessageChunk) -> str:
         cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
         if "args" in cited_answer and "answer" in cited_answer["args"]:
             return cited_answer["args"]["answer"]
+    return ""
```
```diff
@@ -117,11 +117,9 @@ def parse_chunk_response(
     gathered_msg: AIMessageChunk,
     raw_chunk: dict[str, Any],
     supports_func_calling: bool,
-) -> Tuple[AIMessageChunk, ParsedRAGChunkResponse]:
+) -> Tuple[AIMessageChunk, str]:
     # Init with sources
     answer_str = ""
-    # Get the previously parsed answer
-    prev_answer = get_prev_message_str(gathered_msg)

     if supports_func_calling:
         gathered_msg += raw_chunk["answer"]
```
```diff
@@ -133,16 +131,10 @@ def parse_chunk_response(
             gathered_args = cited_answer["args"]
             if "answer" in gathered_args:
-                # Only send the difference between answer and response_tokens which was the previous answer
-                gathered_answer = gathered_args["answer"]
-                answer_str: str = gathered_answer[len(prev_answer) :]
-
-        return gathered_msg, ParsedRAGChunkResponse(
-            answer=answer_str, metadata=RAGResponseMetadata()
-        )
+                answer_str = gathered_args["answer"]
+
+        return gathered_msg, answer_str
     else:
-        return gathered_msg, ParsedRAGChunkResponse(
-            answer=raw_chunk["answer"].content, metadata=RAGResponseMetadata()
-        )
+        return gathered_msg, raw_chunk["answer"].content


 def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
```
```diff
@@ -165,6 +157,7 @@ def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
             metadata["thoughts"] = thoughts
         answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]

+    breakpoint()
     parsed_response = ParsedRAGResponse(
         answer=answer, metadata=RAGResponseMetadata(**metadata)
     )
```