fix(memory): added memory back

Stan Girard 2023-06-15 15:25:12 +02:00
parent 8fb245fe2a
commit 82c74186a8
2 changed files with 8 additions and 14 deletions


@@ -82,7 +82,7 @@ def create_clients_and_embeddings(openai_api_key, supabase_url, supabase_key):
     return supabase_client, embeddings

-def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str, with_sources: bool = True):
+def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str, with_sources: bool = False):
     '''Get the question answering language model.'''
     openai_api_key, anthropic_api_key, supabase_url, supabase_key = get_environment_variables()
@@ -94,15 +94,10 @@ def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str
     vector_store = CustomSupabaseVectorStore(
         supabase_client, embeddings, table_name="vectors", user_id=user_id)

-    if with_sources:
-        memory = AnswerConversationBufferMemory(
-            memory_key="chat_history", return_messages=True)
-    else:
-        memory = ConversationBufferMemory(
-            memory_key="chat_history", return_messages=True)
+    memory = ConversationBufferMemory(
+        memory_key="chat_history", return_messages=True)

     qa = None
     # this overwrites the built-in prompt of the ConversationalRetrievalChain
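The dropped branch suggests AnswerConversationBufferMemory existed because ConversationBufferMemory cannot save a turn when the chain returns both an answer and source_documents (LangChain's "One output key expected" error). A minimal sketch of such a wrapper, assuming it follows the common LangChain workaround — only the class name appears in this diff, the body is an assumption:

from langchain.memory import ConversationBufferMemory

class AnswerConversationBufferMemory(ConversationBufferMemory):
    # Assumed implementation: collapse the chain's multi-key output so the
    # base memory sees a single output key and can store the turn.
    def save_context(self, inputs: dict, outputs: dict) -> None:
        return super().save_context(inputs, {"response": outputs["answer"]})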
@@ -112,20 +107,20 @@ def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str
         ChatOpenAI(
             model_name=chat_message.model, openai_api_key=openai_api_key,
             temperature=chat_message.temperature, max_tokens=chat_message.max_tokens),
-            vector_store.as_retriever(), memory=memory, verbose=True,
+            vector_store.as_retriever(), verbose=True,
             return_source_documents=with_sources,
             max_tokens_limit=1024)
         qa.combine_docs_chain = load_qa_chain(OpenAI(temperature=chat_message.temperature, model_name=chat_message.model, max_tokens=chat_message.max_tokens), chain_type="stuff", prompt=LANGUAGE_PROMPT.QA_PROMPT)
     elif chat_message.model.startswith("vertex"):
         qa = ConversationalRetrievalChain.from_llm(
-            ChatVertexAI(), vector_store.as_retriever(), memory=memory, verbose=True,
+            ChatVertexAI(), vector_store.as_retriever(), verbose=True,
             return_source_documents=with_sources, max_tokens_limit=1024)
         qa.combine_docs_chain = load_qa_chain(ChatVertexAI(), chain_type="stuff", prompt=LANGUAGE_PROMPT.QA_PROMPT)
     elif anthropic_api_key and chat_message.model.startswith("claude"):
         qa = ConversationalRetrievalChain.from_llm(
             ChatAnthropic(
                 model=chat_message.model, anthropic_api_key=anthropic_api_key, temperature=chat_message.temperature, max_tokens_to_sample=chat_message.max_tokens),
-            vector_store.as_retriever(), memory=memory, verbose=False,
+            vector_store.as_retriever(), verbose=False,
             return_source_documents=with_sources,
             max_tokens_limit=102400)
         qa.combine_docs_chain = load_qa_chain(ChatAnthropic(), chain_type="stuff", prompt=LANGUAGE_PROMPT.QA_PROMPT)
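With memory=memory removed from every ConversationalRetrievalChain.from_llm call, the chain no longer records turns itself; the history must arrive with each call. A minimal sketch of that calling pattern, assuming a LangChain ConversationalRetrievalChain built as above (the model name and sample history below are placeholders, not taken from this repo):

from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI

# 'vector_store' is assumed to be the CustomSupabaseVectorStore built earlier.
qa = ConversationalRetrievalChain.from_llm(
    ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
    vector_store.as_retriever(),
    return_source_documents=True,
    max_tokens_limit=1024)

# Without an attached memory, the chain expects 'chat_history' in the inputs,
# by default a list of (human, ai) message tuples.
history = [("What does Quivr store?", "Document embeddings in Supabase.")]
result = qa({"question": "Which models can it use?", "chat_history": history})
print(result["answer"])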


@@ -10,7 +10,6 @@ from llm.summarization import llm_evaluate_summaries, llm_summerize
 from logger import get_logger
 from models.chats import ChatMessage
 from models.users import User
-from supabase import Client, create_client

 logger = get_logger(__name__)
@@ -156,7 +155,7 @@ def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user
         model_response = qa(
             {"question": additional_context + chat_message.question})
     else:
-        model_response = qa({"question": chat_message.question})
+        model_response = qa({"question": chat_message.question, "chat_history": chat_message.history})

     answer = model_response['answer']
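Passing chat_message.history straight through makes the caller responsible for accumulating turns. A sketch of that caller-side loop, assuming the history is a list of (question, answer) pairs as ConversationalRetrievalChain accepts by default ('ask' is an illustrative helper, not code from this repo):

def ask(qa, history: list, question: str) -> str:
    # Supply the running history explicitly; the chain no longer keeps its own.
    response = qa({"question": question, "chat_history": history})
    answer = response["answer"]
    history.append((question, answer))  # record the turn for the next call
    return answer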