refactor(brainpicking): removed one function

This commit is contained in:
Stan Girard 2023-06-19 23:14:42 +02:00
parent 969e0b48a8
commit 99258790ad
3 changed files with 44 additions and 31 deletions

View File

@ -59,7 +59,7 @@ class BrainPicking(BaseModel):
self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
return self
def get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain:
def _get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain:
if user_openai_api_key is not None and user_openai_api_key != "":
self.settings.openai_api_key = user_openai_api_key
qa = ConversationalRetrievalChain(
@ -67,3 +67,16 @@ class BrainPicking(BaseModel):
max_tokens_limit=chat_message.max_tokens, question_generator=self.question_generator,
combine_docs_chain=self.doc_chain, get_chat_history=get_chat_history)
return qa
def generate_answer(self, chat_message: ChatMessage, user_openai_api_key) -> str:
transformed_history = []
qa = self._get_qa(chat_message, user_openai_api_key)
for i in range(0, len(chat_message.history) - 1, 2):
user_message = chat_message.history[i][1]
assistant_message = chat_message.history[i + 1][1]
transformed_history.append((user_message, assistant_message))
model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
answer = model_response['answer']
return answer

View File

@ -12,6 +12,7 @@ from utils.chats import (create_chat, get_chat_name_from_first_question,
from utils.users import (create_user, fetch_user_id_from_credentials,
update_user_request_count)
from utils.vectors import get_answer
from llm.brainpicking import BrainPicking
chat_router = APIRouter()
@ -101,8 +102,8 @@ def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=Fal
return {"history": history}
answer = get_answer(commons, chat_message, email, user_openai_api_key)
brainPicking = BrainPicking().init(chat_message.model, email)
answer = brainPicking.generate_answer(chat_message, user_openai_api_key)
history.append(("assistant", answer))
if is_new_chat:

View File

@ -56,37 +56,36 @@ def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user_
Brain = BrainPicking().init(chat_message.model, email)
qa = Brain.get_qa(chat_message, user_openai_api_key)
neurons = Neurons(commons=commons)
if chat_message.use_summarization:
summaries = neurons.similarity_search(chat_message.question, table='match_summaries')
evaluations = llm_evaluate_summaries(
chat_message.question, summaries, chat_message.model)
if evaluations:
response = commons['supabase'].from_('vectors').select(
'*').in_('id', values=[e['document_id'] for e in evaluations]).execute()
additional_context = '---\nAdditional Context={}'.format(
'---\n'.join(data['content'] for data in response.data)
) + '\n'
model_response = qa(
{"question": additional_context + chat_message.question})
else:
transformed_history = []
# if chat_message.use_summarization:
# summaries = neurons.similarity_search(chat_message.question, table='match_summaries')
# evaluations = llm_evaluate_summaries(
# chat_message.question, summaries, chat_message.model)
# if evaluations:
# response = commons['supabase'].from_('vectors').select(
# '*').in_('id', values=[e['document_id'] for e in evaluations]).execute()
# additional_context = '---\nAdditional Context={}'.format(
# '---\n'.join(data['content'] for data in response.data)
# ) + '\n'
# model_response = qa(
# {"question": additional_context + chat_message.question})
# else:
# transformed_history = []
for i in range(0, len(chat_message.history) - 1, 2):
user_message = chat_message.history[i][1]
assistant_message = chat_message.history[i + 1][1]
transformed_history.append((user_message, assistant_message))
model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
# for i in range(0, len(chat_message.history) - 1, 2):
# user_message = chat_message.history[i][1]
# assistant_message = chat_message.history[i + 1][1]
# transformed_history.append((user_message, assistant_message))
# model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
answer = model_response['answer']
# answer = model_response['answer']
if "source_documents" in answer:
sources = [
doc.metadata["file_name"] for doc in answer["source_documents"]
if "file_name" in doc.metadata]
if sources:
files = dict.fromkeys(sources)
answer = answer + "\n\nRef: " + "; ".join(files)
# if "source_documents" in answer:
# sources = [
# doc.metadata["file_name"] for doc in answer["source_documents"]
# if "file_name" in doc.metadata]
# if sources:
# files = dict.fromkeys(sources)
# answer = answer + "\n\nRef: " + "; ".join(files)
return answer