mirror of
https://github.com/QuivrHQ/quivr.git
synced 2025-01-05 23:03:53 +03:00
refactor(brainpicking): removed one function
This commit is contained in:
parent
969e0b48a8
commit
99258790ad
@ -59,7 +59,7 @@ class BrainPicking(BaseModel):
|
|||||||
self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
|
self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain:
|
def _get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain:
|
||||||
if user_openai_api_key is not None and user_openai_api_key != "":
|
if user_openai_api_key is not None and user_openai_api_key != "":
|
||||||
self.settings.openai_api_key = user_openai_api_key
|
self.settings.openai_api_key = user_openai_api_key
|
||||||
qa = ConversationalRetrievalChain(
|
qa = ConversationalRetrievalChain(
|
||||||
@ -67,3 +67,16 @@ class BrainPicking(BaseModel):
|
|||||||
max_tokens_limit=chat_message.max_tokens, question_generator=self.question_generator,
|
max_tokens_limit=chat_message.max_tokens, question_generator=self.question_generator,
|
||||||
combine_docs_chain=self.doc_chain, get_chat_history=get_chat_history)
|
combine_docs_chain=self.doc_chain, get_chat_history=get_chat_history)
|
||||||
return qa
|
return qa
|
||||||
|
|
||||||
|
def generate_answer(self, chat_message: ChatMessage, user_openai_api_key) -> str:
|
||||||
|
transformed_history = []
|
||||||
|
|
||||||
|
qa = self._get_qa(chat_message, user_openai_api_key)
|
||||||
|
for i in range(0, len(chat_message.history) - 1, 2):
|
||||||
|
user_message = chat_message.history[i][1]
|
||||||
|
assistant_message = chat_message.history[i + 1][1]
|
||||||
|
transformed_history.append((user_message, assistant_message))
|
||||||
|
model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
|
||||||
|
answer = model_response['answer']
|
||||||
|
|
||||||
|
return answer
|
@ -12,6 +12,7 @@ from utils.chats import (create_chat, get_chat_name_from_first_question,
|
|||||||
from utils.users import (create_user, fetch_user_id_from_credentials,
|
from utils.users import (create_user, fetch_user_id_from_credentials,
|
||||||
update_user_request_count)
|
update_user_request_count)
|
||||||
from utils.vectors import get_answer
|
from utils.vectors import get_answer
|
||||||
|
from llm.brainpicking import BrainPicking
|
||||||
|
|
||||||
chat_router = APIRouter()
|
chat_router = APIRouter()
|
||||||
|
|
||||||
@ -101,8 +102,8 @@ def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=Fal
|
|||||||
return {"history": history}
|
return {"history": history}
|
||||||
|
|
||||||
|
|
||||||
|
brainPicking = BrainPicking().init(chat_message.model, email)
|
||||||
answer = get_answer(commons, chat_message, email, user_openai_api_key)
|
answer = brainPicking.generate_answer(chat_message, user_openai_api_key)
|
||||||
history.append(("assistant", answer))
|
history.append(("assistant", answer))
|
||||||
|
|
||||||
if is_new_chat:
|
if is_new_chat:
|
||||||
|
@ -56,37 +56,36 @@ def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user_
|
|||||||
Brain = BrainPicking().init(chat_message.model, email)
|
Brain = BrainPicking().init(chat_message.model, email)
|
||||||
qa = Brain.get_qa(chat_message, user_openai_api_key)
|
qa = Brain.get_qa(chat_message, user_openai_api_key)
|
||||||
|
|
||||||
neurons = Neurons(commons=commons)
|
|
||||||
|
|
||||||
if chat_message.use_summarization:
|
# if chat_message.use_summarization:
|
||||||
summaries = neurons.similarity_search(chat_message.question, table='match_summaries')
|
# summaries = neurons.similarity_search(chat_message.question, table='match_summaries')
|
||||||
evaluations = llm_evaluate_summaries(
|
# evaluations = llm_evaluate_summaries(
|
||||||
chat_message.question, summaries, chat_message.model)
|
# chat_message.question, summaries, chat_message.model)
|
||||||
if evaluations:
|
# if evaluations:
|
||||||
response = commons['supabase'].from_('vectors').select(
|
# response = commons['supabase'].from_('vectors').select(
|
||||||
'*').in_('id', values=[e['document_id'] for e in evaluations]).execute()
|
# '*').in_('id', values=[e['document_id'] for e in evaluations]).execute()
|
||||||
additional_context = '---\nAdditional Context={}'.format(
|
# additional_context = '---\nAdditional Context={}'.format(
|
||||||
'---\n'.join(data['content'] for data in response.data)
|
# '---\n'.join(data['content'] for data in response.data)
|
||||||
) + '\n'
|
# ) + '\n'
|
||||||
model_response = qa(
|
# model_response = qa(
|
||||||
{"question": additional_context + chat_message.question})
|
# {"question": additional_context + chat_message.question})
|
||||||
else:
|
# else:
|
||||||
transformed_history = []
|
# transformed_history = []
|
||||||
|
|
||||||
for i in range(0, len(chat_message.history) - 1, 2):
|
# for i in range(0, len(chat_message.history) - 1, 2):
|
||||||
user_message = chat_message.history[i][1]
|
# user_message = chat_message.history[i][1]
|
||||||
assistant_message = chat_message.history[i + 1][1]
|
# assistant_message = chat_message.history[i + 1][1]
|
||||||
transformed_history.append((user_message, assistant_message))
|
# transformed_history.append((user_message, assistant_message))
|
||||||
model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
|
# model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
|
||||||
|
|
||||||
answer = model_response['answer']
|
# answer = model_response['answer']
|
||||||
|
|
||||||
if "source_documents" in answer:
|
# if "source_documents" in answer:
|
||||||
sources = [
|
# sources = [
|
||||||
doc.metadata["file_name"] for doc in answer["source_documents"]
|
# doc.metadata["file_name"] for doc in answer["source_documents"]
|
||||||
if "file_name" in doc.metadata]
|
# if "file_name" in doc.metadata]
|
||||||
if sources:
|
# if sources:
|
||||||
files = dict.fromkeys(sources)
|
# files = dict.fromkeys(sources)
|
||||||
answer = answer + "\n\nRef: " + "; ".join(files)
|
# answer = answer + "\n\nRef: " + "; ".join(files)
|
||||||
|
|
||||||
return answer
|
return answer
|
Loading…
Reference in New Issue
Block a user