fix(prompt): fixed qa prompt

This commit is contained in:
Stan Girard 2023-06-13 10:35:06 +02:00
parent ec29f30f32
commit 9f0ba8bce7
2 changed files with 8 additions and 12 deletions

View File

@ -1,14 +1,7 @@
from langchain.prompts.prompt import PromptTemplate
_template = """Given the following conversation and a follow up question, answer the follow up question in the initial language of the question. If you don't know the answer, just say that you don't know, don't try to make up an answer.
prompt_template = """Use the following pieces of context to answer the question in the language of the question. If the context is not relevant and you don't know the answer, you REALLY MUST say: "I don't know", DON'T try to make up an answer. If you answer without knowing a human will get killed.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
prompt_template = """Use the following pieces of context to answer the question in the language of the question. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}

View File

@ -2,15 +2,17 @@ import os
from typing import Any, Dict, List
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI, ChatVertexAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import VertexAI
from langchain.llms import OpenAI, VertexAI
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores import SupabaseVectorStore
from llm import LANGUAGE_PROMPT
from models.chats import ChatMessage
from supabase import Client, create_client
@ -101,11 +103,11 @@ def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str
memory = ConversationBufferMemory(
memory_key="chat_history", return_messages=True)
ConversationalRetrievalChain.prompts = LANGUAGE_PROMPT
qa = None
# this overwrites the built-in prompt of the ConversationalRetrievalChain
ConversationalRetrievalChain.prompts = LANGUAGE_PROMPT
doc_chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", prompt=LANGUAGE_PROMPT.QA_PROMPT)
if chat_message.model.startswith("gpt"):
qa = ConversationalRetrievalChain.from_llm(
ChatOpenAI(
@ -116,7 +118,7 @@ def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str
max_tokens_limit=1024)
elif chat_message.model.startswith("vertex"):
qa = ConversationalRetrievalChain.from_llm(
ChatVertexAI(), vector_store.as_retriever(), memory=memory, verbose=False,
ChatVertexAI(), vector_store.as_retriever(), memory=memory, verbose=True,
return_source_documents=with_sources, max_tokens_limit=1024)
elif anthropic_api_key and chat_message.model.startswith("claude"):
qa = ConversationalRetrievalChain.from_llm(
@ -125,4 +127,5 @@ def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str
vector_store.as_retriever(), memory=memory, verbose=False,
return_source_documents=with_sources,
max_tokens_limit=102400)
qa.combine_docs_chain = doc_chain
return qa