mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-15 17:43:03 +03:00
refactor(brainpicking): removed one function
This commit is contained in:
parent
969e0b48a8
commit
99258790ad
@ -59,7 +59,7 @@ class BrainPicking(BaseModel):
|
||||
self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
|
||||
return self
|
||||
|
||||
def get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain:
|
||||
def _get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain:
|
||||
if user_openai_api_key is not None and user_openai_api_key != "":
|
||||
self.settings.openai_api_key = user_openai_api_key
|
||||
qa = ConversationalRetrievalChain(
|
||||
@ -67,3 +67,16 @@ class BrainPicking(BaseModel):
|
||||
max_tokens_limit=chat_message.max_tokens, question_generator=self.question_generator,
|
||||
combine_docs_chain=self.doc_chain, get_chat_history=get_chat_history)
|
||||
return qa
|
||||
|
||||
def generate_answer(self, chat_message: ChatMessage, user_openai_api_key) -> str:
|
||||
transformed_history = []
|
||||
|
||||
qa = self._get_qa(chat_message, user_openai_api_key)
|
||||
for i in range(0, len(chat_message.history) - 1, 2):
|
||||
user_message = chat_message.history[i][1]
|
||||
assistant_message = chat_message.history[i + 1][1]
|
||||
transformed_history.append((user_message, assistant_message))
|
||||
model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
|
||||
answer = model_response['answer']
|
||||
|
||||
return answer
|
@ -12,6 +12,7 @@ from utils.chats import (create_chat, get_chat_name_from_first_question,
|
||||
from utils.users import (create_user, fetch_user_id_from_credentials,
|
||||
update_user_request_count)
|
||||
from utils.vectors import get_answer
|
||||
from llm.brainpicking import BrainPicking
|
||||
|
||||
chat_router = APIRouter()
|
||||
|
||||
@ -101,8 +102,8 @@ def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=Fal
|
||||
return {"history": history}
|
||||
|
||||
|
||||
|
||||
answer = get_answer(commons, chat_message, email, user_openai_api_key)
|
||||
brainPicking = BrainPicking().init(chat_message.model, email)
|
||||
answer = brainPicking.generate_answer(chat_message, user_openai_api_key)
|
||||
history.append(("assistant", answer))
|
||||
|
||||
if is_new_chat:
|
||||
|
@ -56,37 +56,36 @@ def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user_
|
||||
Brain = BrainPicking().init(chat_message.model, email)
|
||||
qa = Brain.get_qa(chat_message, user_openai_api_key)
|
||||
|
||||
neurons = Neurons(commons=commons)
|
||||
|
||||
if chat_message.use_summarization:
|
||||
summaries = neurons.similarity_search(chat_message.question, table='match_summaries')
|
||||
evaluations = llm_evaluate_summaries(
|
||||
chat_message.question, summaries, chat_message.model)
|
||||
if evaluations:
|
||||
response = commons['supabase'].from_('vectors').select(
|
||||
'*').in_('id', values=[e['document_id'] for e in evaluations]).execute()
|
||||
additional_context = '---\nAdditional Context={}'.format(
|
||||
'---\n'.join(data['content'] for data in response.data)
|
||||
) + '\n'
|
||||
model_response = qa(
|
||||
{"question": additional_context + chat_message.question})
|
||||
else:
|
||||
transformed_history = []
|
||||
# if chat_message.use_summarization:
|
||||
# summaries = neurons.similarity_search(chat_message.question, table='match_summaries')
|
||||
# evaluations = llm_evaluate_summaries(
|
||||
# chat_message.question, summaries, chat_message.model)
|
||||
# if evaluations:
|
||||
# response = commons['supabase'].from_('vectors').select(
|
||||
# '*').in_('id', values=[e['document_id'] for e in evaluations]).execute()
|
||||
# additional_context = '---\nAdditional Context={}'.format(
|
||||
# '---\n'.join(data['content'] for data in response.data)
|
||||
# ) + '\n'
|
||||
# model_response = qa(
|
||||
# {"question": additional_context + chat_message.question})
|
||||
# else:
|
||||
# transformed_history = []
|
||||
|
||||
for i in range(0, len(chat_message.history) - 1, 2):
|
||||
user_message = chat_message.history[i][1]
|
||||
assistant_message = chat_message.history[i + 1][1]
|
||||
transformed_history.append((user_message, assistant_message))
|
||||
model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
|
||||
# for i in range(0, len(chat_message.history) - 1, 2):
|
||||
# user_message = chat_message.history[i][1]
|
||||
# assistant_message = chat_message.history[i + 1][1]
|
||||
# transformed_history.append((user_message, assistant_message))
|
||||
# model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
|
||||
|
||||
answer = model_response['answer']
|
||||
# answer = model_response['answer']
|
||||
|
||||
if "source_documents" in answer:
|
||||
sources = [
|
||||
doc.metadata["file_name"] for doc in answer["source_documents"]
|
||||
if "file_name" in doc.metadata]
|
||||
if sources:
|
||||
files = dict.fromkeys(sources)
|
||||
answer = answer + "\n\nRef: " + "; ".join(files)
|
||||
# if "source_documents" in answer:
|
||||
# sources = [
|
||||
# doc.metadata["file_name"] for doc in answer["source_documents"]
|
||||
# if "file_name" in doc.metadata]
|
||||
# if sources:
|
||||
# files = dict.fromkeys(sources)
|
||||
# answer = answer + "\n\nRef: " + "; ".join(files)
|
||||
|
||||
return answer
|
Loading…
Reference in New Issue
Block a user