Mirror of https://github.com/QuivrHQ/quivr.git (synced 2024-12-15 01:21:48 +03:00)
fix: fixes duplicate response bug (#2843)
# Description

Closes #2794. Fixes duplicate responses in the stream: with function-calling models, each streamed chunk carries the full answer gathered so far, so the stream must yield only the not-yet-sent suffix of each chunk instead of the whole snapshot.
This commit is contained in:
parent 4b0d0f8144
commit 23ea00726a
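Why the duplication happened, as the diff below shows: when the model supports function calling, every streamed chunk repeats the full answer accumulated so far, and the old code yielded that full snapshot each time. Here is a minimal, LangChain-free sketch of the fix; the names (snapshot_stream, delta_stream) are illustrative, not from the codebase.

from typing import Iterator


def snapshot_stream() -> Iterator[str]:
    # Stand-in for the model stream: each chunk repeats everything so far.
    yield "Par"
    yield "Paris is"
    yield "Paris is the capital."


def delta_stream() -> Iterator[str]:
    prev_answer = ""
    for answer_str in snapshot_stream():
        diff_answer = answer_str[len(prev_answer):]  # only the unseen suffix
        if diff_answer:
            prev_answer += diff_answer
            yield diff_answer


# Before the fix a client effectively saw "Par" + "Paris is" + ...; now:
assert "".join(delta_stream()) == "Paris is the capital."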
@@ -1,20 +1,24 @@
 import logging
 from operator import itemgetter
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional, Sequence
 
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain_cohere import CohereRerank
 from langchain_community.chat_models import ChatLiteLLM
+from langchain_core.callbacks import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from langchain_core.vectorstores import VectorStore
 from langchain_openai import ChatOpenAI
+
 from quivr_api.modules.knowledge.entity.knowledge import Knowledge
 from quivr_api.packages.quivr_core.config import RAGConfig
 from quivr_api.packages.quivr_core.models import (
     ParsedRAGChunkResponse,
     ParsedRAGResponse,
+    RAGResponseMetadata,
     cited_answer,
 )
 from quivr_api.packages.quivr_core.prompts import (
@@ -33,6 +37,16 @@ from quivr_api.packages.quivr_core.utils import (
 logger = logging.getLogger(__name__)
 
 
+class IdempotentCompressor(BaseDocumentCompressor):
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        return documents
+
+
 class QuivrQARAG:
     def __init__(
         self,
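A quick, hedged check of the contract the new class above satisfies: it implements LangChain's BaseDocumentCompressor interface but returns its input unchanged, so downstream retrieval code needs no special-casing when reranking is unavailable. This assumes IdempotentCompressor is importable as defined above.

from langchain_core.documents import Document

docs = [Document(page_content="a"), Document(page_content="b")]
assert IdempotentCompressor().compress_documents(docs, query="anything") == docs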
@@ -55,13 +69,11 @@
 
     def _create_reranker(self):
         # TODO: reranker config
-        compressor = CohereRerank(top_n=20)
-        # else:
-        #     ranker_model_name = "ms-marco-TinyBERT-L-2-v2"
-        #     flashrank_client = Ranker(model_name=ranker_model_name)
-        #     compressor = FlashrankRerank(
-        #         client=flashrank_client, model=ranker_model_name, top_n=20
-        #     )  # TODO @stangirard fix
+        try:
+            compressor = CohereRerank(top_n=20)
+        except Exception as e:
+            logger.exception(f"Can't load Cohere reranker: {e}")
+            compressor = IdempotentCompressor()
         return compressor
 
     # TODO : refactor and simplify
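For context, a hedged sketch of how the returned compressor is typically wired in; the diff does not show the retriever construction, so build_retriever and vector_store below are assumptions, not project code. Whichever object _create_reranker returns, CohereRerank or the no-op fallback, plugs into a ContextualCompressionRetriever unchanged.

from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank


def build_retriever(vector_store):
    # Mirrors the fallback above: try Cohere, degrade to a no-op reranker.
    try:
        compressor = CohereRerank(top_n=20)
    except Exception:
        compressor = IdempotentCompressor()
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=vector_store.as_retriever(),
    )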
@@ -194,6 +206,7 @@
 
         rolling_message = AIMessageChunk(content="")
         sources = []
+        prev_answer = ""
 
         async for chunk in conversational_qa_chain.astream(
             {
@@ -208,18 +221,22 @@
             sources = chunk["docs"] if "docs" in chunk else []
 
             if "answer" in chunk:
-                rolling_message, parsed_chunk = parse_chunk_response(
+                rolling_message, answer_str = parse_chunk_response(
                     rolling_message,
                     chunk,
                     self.supports_func_calling,
                 )
 
-                if self.supports_func_calling and len(parsed_chunk.answer) > 0:
-                    yield parsed_chunk
-                else:
-                    yield parsed_chunk
+                if self.supports_func_calling and len(answer_str) > 0:
+                    diff_answer = answer_str[len(prev_answer) :]
+                    parsed_chunk = ParsedRAGChunkResponse(
+                        answer=diff_answer,
+                        metadata=RAGResponseMetadata(),
+                    )
+                    prev_answer += diff_answer
+                    yield parsed_chunk
 
-        # Last chunk provies
+        # Last chunk provides metadata
         yield ParsedRAGChunkResponse(
             answer="",
             metadata=get_chunk_metadata(rolling_message, sources),
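A hedged consumer sketch: because the stream above now emits deltas, a client can rebuild the answer by plain concatenation. The streaming method's signature is not shown in the diff, so the invocation in the trailing comment is an assumption.

import asyncio


async def collect(stream) -> str:
    answer = ""
    async for parsed_chunk in stream:
        answer += parsed_chunk.answer  # each chunk now carries only new text
    return answer


# e.g. asyncio.run(collect(rag_stream)), where rag_stream is the async
# generator produced by the streaming method above (assumed usage).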
@@ -18,7 +18,6 @@ from quivr_api.modules.upload.service.generate_file_signed_url import (
     generate_file_signed_url,
 )
 from quivr_api.packages.quivr_core.models import (
-    ParsedRAGChunkResponse,
     ParsedRAGResponse,
     RAGResponseMetadata,
     RawRAGResponse,
@@ -83,6 +82,7 @@ def get_prev_message_str(msg: AIMessageChunk) -> str:
         cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
-        return cited_answer["args"]["answer"]
+        if "args" in cited_answer and "answer" in cited_answer["args"]:
+            return cited_answer["args"]["answer"]
 
     return ""
 
 
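A LangChain-free sketch of the guard added above: early tool-call chunks may not have accumulated "args" or "answer" yet, hence the membership checks instead of unguarded indexing. prev_answer_from is an illustrative stand-in, not project code.

def prev_answer_from(tool_calls: list) -> str:
    cited = next((c for c in tool_calls if c.get("name") == "cited_answer"), None)
    if cited is not None and "args" in cited and "answer" in cited["args"]:
        return cited["args"]["answer"]
    return ""


assert prev_answer_from([{"name": "cited_answer"}]) == ""
assert prev_answer_from([{"name": "cited_answer", "args": {"answer": "hi"}}]) == "hi"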
@@ -117,11 +117,9 @@ def parse_chunk_response(
     gathered_msg: AIMessageChunk,
     raw_chunk: dict[str, Any],
     supports_func_calling: bool,
-) -> Tuple[AIMessageChunk, ParsedRAGChunkResponse]:
-    # Init with sources
-    answer_str = ""
+) -> Tuple[AIMessageChunk, str]:
     # Get the previously parsed answer
     prev_answer = get_prev_message_str(gathered_msg)
 
     if supports_func_calling:
         gathered_msg += raw_chunk["answer"]
@@ -133,16 +131,10 @@
             gathered_args = cited_answer["args"]
             if "answer" in gathered_args:
                 # Only send the difference between answer and response_tokens which was the previous answer
-                gathered_answer = gathered_args["answer"]
-                answer_str: str = gathered_answer[len(prev_answer) :]
-
-        return gathered_msg, ParsedRAGChunkResponse(
-            answer=answer_str, metadata=RAGResponseMetadata()
-        )
+                answer_str = gathered_args["answer"]
+        return gathered_msg, answer_str
     else:
-        return gathered_msg, ParsedRAGChunkResponse(
-            answer=raw_chunk["answer"].content, metadata=RAGResponseMetadata()
-        )
+        return gathered_msg, raw_chunk["answer"].content
 
 
 def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
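The net effect of the two hunks above, sketched without LangChain types: parse_chunk_response now returns a plain string in both branches, the full gathered answer for function-calling models and the raw chunk content otherwise, leaving chunk wrapping and de-duplication to the streaming caller. Names below are illustrative.

def parsed_answer(raw_chunk: dict, supports_func_calling: bool, gathered: str) -> str:
    if supports_func_calling:
        return gathered  # snapshot of everything gathered so far
    return raw_chunk["answer"]  # non-func-calling chunks are already deltas


assert parsed_answer({"answer": "hi"}, False, "") == "hi"
assert parsed_answer({}, True, "Hello so far") == "Hello so far"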
@@ -165,6 +157,7 @@ def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
         metadata["thoughts"] = thoughts
     answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]
 
+    breakpoint()
     parsed_response = ParsedRAGResponse(
         answer=answer, metadata=RAGResponseMetadata(**metadata)
     )