From 41dec746a72e6b2753d5a295a29c10025d9148d9 Mon Sep 17 00:00:00 2001
From: Stan Girard
Date: Wed, 29 May 2024 22:31:25 +0200
Subject: [PATCH] fix: Refactor conversational_qa_chain initialization in
 KnowledgeBrainQA (#2629)

# Description

Refactor `conversational_qa_chain` initialization in
`KnowledgeBrainQA.generate_stream`: instead of calling `self.get_chain()`
twice (once for the truthiness test and once for the value), check for a
callable `get_chain` attribute with `hasattr`/`callable` and call it only
once, falling back to `self.knowledge_qa.get_chain()` otherwise. Also
tidies the system prompt wording in `quivr_rag.py` and collapses the
wrapped `langchain.prompts` import onto a single line.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---
 backend/modules/brain/knowledge_brain_qa.py | 7 ++++---
 backend/modules/brain/rags/quivr_rag.py     | 6 +++---
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/backend/modules/brain/knowledge_brain_qa.py b/backend/modules/brain/knowledge_brain_qa.py
index 72fabd30f..9b36d6455 100644
--- a/backend/modules/brain/knowledge_brain_qa.py
+++ b/backend/modules/brain/knowledge_brain_qa.py
@@ -331,9 +331,10 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
     async def generate_stream(
         self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
     ) -> AsyncIterable:
-        conversational_qa_chain = (
-            self.get_chain() if self.get_chain() else self.knowledge_qa.get_chain()
-        )
+        if hasattr(self, "get_chain") and callable(getattr(self, "get_chain")):
+            conversational_qa_chain = self.get_chain()
+        else:
+            conversational_qa_chain = self.knowledge_qa.get_chain()
         transformed_history, streamed_chat_history = (
             self.initialize_streamed_chat_history(chat_id, question)
         )
diff --git a/backend/modules/brain/rags/quivr_rag.py b/backend/modules/brain/rags/quivr_rag.py
index 256c31682..cf99dc7b6 100644
--- a/backend/modules/brain/rags/quivr_rag.py
+++ b/backend/modules/brain/rags/quivr_rag.py
@@ -6,8 +6,7 @@ from uuid import UUID
 
 from langchain.chains import ConversationalRetrievalChain
 from langchain.llms.base import BaseLLM
-from langchain.prompts import (HumanMessagePromptTemplate,
-                               SystemMessagePromptTemplate)
+from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import FlashrankRerank
 from langchain.schema import format_document
@@ -89,7 +88,8 @@ system_message_template = (
 )
 
 system_message_template += """
-When answering use markdown neat.
+When answering use markdown.
+Use markdown code blocks for code snippets.
 Answer in a concise and clear manner.
 Use the following pieces of context from files provided by the user to answer the users.
 Answer in the same language as the user question.