Mirror of https://github.com/StanGirard/quivr.git, synced 2025-01-04 01:03:45 +03:00
fix(memory): added memory back

parent 8fb245fe2a
commit 82c74186a8
@@ -82,7 +82,7 @@ def create_clients_and_embeddings(openai_api_key, supabase_url, supabase_key):
     return supabase_client, embeddings


-def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str, with_sources: bool = True):
+def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str, with_sources: bool = False):
     '''Get the question answering language model.'''
     openai_api_key, anthropic_api_key, supabase_url, supabase_key = get_environment_variables()
@@ -94,15 +94,10 @@ def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str
     vector_store = CustomSupabaseVectorStore(
         supabase_client, embeddings, table_name="vectors", user_id=user_id)

-    if with_sources:
-        memory = AnswerConversationBufferMemory(
-            memory_key="chat_history", return_messages=True)
-    else:
-        memory = ConversationBufferMemory(
-            memory_key="chat_history", return_messages=True)
+    memory = ConversationBufferMemory(
+        memory_key="chat_history", return_messages=True)

     qa = None
     # this overwrites the built-in prompt of the ConversationalRetrievalChain
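The new code always builds a plain ConversationBufferMemory and drops the AnswerConversationBufferMemory special case that existed for the with_sources path. A minimal sketch of what that object does, assuming the 2023-era langchain package this code targets (the import path and example strings are illustrative, not part of the diff):

from langchain.memory import ConversationBufferMemory

# memory_key="chat_history" sets the key under which past turns are exposed;
# return_messages=True yields HumanMessage/AIMessage objects rather than one
# pre-formatted string.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
memory.save_context({"input": "What is Quivr?"},
                    {"output": "A second brain for your documents."})

print(memory.load_memory_variables({}))
# {'chat_history': [HumanMessage(content='What is Quivr?'),
#                   AIMessage(content='A second brain for your documents.')]}

Note that after this change the object is created but no longer handed to the chains below, so history tracking effectively moves to the caller (see the last hunk).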
@@ -112,20 +107,20 @@ def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str
             ChatOpenAI(
                 model_name=chat_message.model, openai_api_key=openai_api_key,
                 temperature=chat_message.temperature, max_tokens=chat_message.max_tokens),
-            vector_store.as_retriever(), memory=memory, verbose=True,
+            vector_store.as_retriever(), verbose=True,
             return_source_documents=with_sources,
             max_tokens_limit=1024)
         qa.combine_docs_chain = load_qa_chain(OpenAI(temperature=chat_message.temperature, model_name=chat_message.model, max_tokens=chat_message.max_tokens), chain_type="stuff", prompt=LANGUAGE_PROMPT.QA_PROMPT)
     elif chat_message.model.startswith("vertex"):
         qa = ConversationalRetrievalChain.from_llm(
-            ChatVertexAI(), vector_store.as_retriever(), memory=memory, verbose=True,
+            ChatVertexAI(), vector_store.as_retriever(), verbose=True,
             return_source_documents=with_sources, max_tokens_limit=1024)
         qa.combine_docs_chain = load_qa_chain(ChatVertexAI(), chain_type="stuff", prompt=LANGUAGE_PROMPT.QA_PROMPT)
     elif anthropic_api_key and chat_message.model.startswith("claude"):
         qa = ConversationalRetrievalChain.from_llm(
             ChatAnthropic(
                 model=chat_message.model, anthropic_api_key=anthropic_api_key, temperature=chat_message.temperature, max_tokens_to_sample=chat_message.max_tokens),
-            vector_store.as_retriever(), memory=memory, verbose=False,
+            vector_store.as_retriever(), verbose=False,
             return_source_documents=with_sources,
             max_tokens_limit=102400)
         qa.combine_docs_chain = load_qa_chain(ChatAnthropic(), chain_type="stuff", prompt=LANGUAGE_PROMPT.QA_PROMPT)
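With the memory= kwarg removed, none of the three ConversationalRetrievalChain variants tracks history itself; "chat_history" has to arrive as an input key on every call (see the last hunk). A rough sketch of the construction pattern, assuming 2023-era langchain; the TF-IDF retriever and the prompt below are stand-ins for vector_store.as_retriever() and LANGUAGE_PROMPT.QA_PROMPT, neither of which is shown in this diff:

from langchain.chains import ConversationalRetrievalChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.retrievers import TFIDFRetriever  # stand-in retriever; needs scikit-learn

# Stand-in for LANGUAGE_PROMPT.QA_PROMPT; a "stuff" chain expects {context} and {question}.
qa_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="Use the context to answer.\n\n{context}\n\nQuestion: {question}\nAnswer:",
)

retriever = TFIDFRetriever.from_texts(
    ["Quivr stores document chunks as embeddings in a Supabase vectors table."])
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, max_tokens=256)

qa = ConversationalRetrievalChain.from_llm(
    llm,
    retriever,                      # the diff uses vector_store.as_retriever()
    verbose=True,                   # note: no memory= kwarg, matching the new code
    return_source_documents=False,
    max_tokens_limit=1024,
)
# As in the diff, overwrite the built-in combine-docs chain so the custom prompt is used:
qa.combine_docs_chain = load_qa_chain(llm, chain_type="stuff", prompt=qa_prompt)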
@@ -10,7 +10,6 @@ from llm.summarization import llm_evaluate_summaries, llm_summerize
 from logger import get_logger
 from models.chats import ChatMessage
 from models.users import User
-
 from supabase import Client, create_client

 logger = get_logger(__name__)
@@ -156,7 +155,7 @@ def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user
         model_response = qa(
             {"question": additional_context + chat_message.question})
     else:
-        model_response = qa({"question": chat_message.question})
+        model_response = qa({"question": chat_message.question, "chat_history": chat_message.history})

     answer = model_response['answer']
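Since the chains above carry no memory, each call now supplies the history explicitly. A short usage sketch continuing from the chain built in the previous example; the tuple format mirrors what chat_message.history is assumed to hold, and (human, ai) tuples are one of the shapes legacy langchain accepts for this key:

# Without an attached memory, ConversationalRetrievalChain requires
# "chat_history" as an input key, e.g. a list of (human, ai) tuples.
history = [("What is Quivr?", "A second brain for your documents.")]
model_response = qa({"question": "Where are the chunks stored?",
                     "chat_history": history})

answer = model_response["answer"]
# With with_sources=True (return_source_documents=True) the result dict
# would also include a "source_documents" list.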