refactor(brainpicking): removed one function

This commit is contained in:
Stan Girard 2023-06-19 23:14:42 +02:00
parent 969e0b48a8
commit 99258790ad
3 changed files with 44 additions and 31 deletions

View File

@ -59,7 +59,7 @@ class BrainPicking(BaseModel):
self.doc_chain = load_qa_chain(self.llm, chain_type="stuff") self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
return self return self
def get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain: def _get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain:
if user_openai_api_key is not None and user_openai_api_key != "": if user_openai_api_key is not None and user_openai_api_key != "":
self.settings.openai_api_key = user_openai_api_key self.settings.openai_api_key = user_openai_api_key
qa = ConversationalRetrievalChain( qa = ConversationalRetrievalChain(
@ -67,3 +67,16 @@ class BrainPicking(BaseModel):
max_tokens_limit=chat_message.max_tokens, question_generator=self.question_generator, max_tokens_limit=chat_message.max_tokens, question_generator=self.question_generator,
combine_docs_chain=self.doc_chain, get_chat_history=get_chat_history) combine_docs_chain=self.doc_chain, get_chat_history=get_chat_history)
return qa return qa
def generate_answer(self, chat_message: ChatMessage, user_openai_api_key) -> str:
transformed_history = []
qa = self._get_qa(chat_message, user_openai_api_key)
for i in range(0, len(chat_message.history) - 1, 2):
user_message = chat_message.history[i][1]
assistant_message = chat_message.history[i + 1][1]
transformed_history.append((user_message, assistant_message))
model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
answer = model_response['answer']
return answer

View File

@ -12,6 +12,7 @@ from utils.chats import (create_chat, get_chat_name_from_first_question,
from utils.users import (create_user, fetch_user_id_from_credentials, from utils.users import (create_user, fetch_user_id_from_credentials,
update_user_request_count) update_user_request_count)
from utils.vectors import get_answer from utils.vectors import get_answer
from llm.brainpicking import BrainPicking
chat_router = APIRouter() chat_router = APIRouter()
@ -101,8 +102,8 @@ def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=Fal
return {"history": history} return {"history": history}
brainPicking = BrainPicking().init(chat_message.model, email)
answer = get_answer(commons, chat_message, email, user_openai_api_key) answer = brainPicking.generate_answer(chat_message, user_openai_api_key)
history.append(("assistant", answer)) history.append(("assistant", answer))
if is_new_chat: if is_new_chat:

View File

@ -56,37 +56,36 @@ def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user_
Brain = BrainPicking().init(chat_message.model, email) Brain = BrainPicking().init(chat_message.model, email)
qa = Brain.get_qa(chat_message, user_openai_api_key) qa = Brain.get_qa(chat_message, user_openai_api_key)
neurons = Neurons(commons=commons)
if chat_message.use_summarization: # if chat_message.use_summarization:
summaries = neurons.similarity_search(chat_message.question, table='match_summaries') # summaries = neurons.similarity_search(chat_message.question, table='match_summaries')
evaluations = llm_evaluate_summaries( # evaluations = llm_evaluate_summaries(
chat_message.question, summaries, chat_message.model) # chat_message.question, summaries, chat_message.model)
if evaluations: # if evaluations:
response = commons['supabase'].from_('vectors').select( # response = commons['supabase'].from_('vectors').select(
'*').in_('id', values=[e['document_id'] for e in evaluations]).execute() # '*').in_('id', values=[e['document_id'] for e in evaluations]).execute()
additional_context = '---\nAdditional Context={}'.format( # additional_context = '---\nAdditional Context={}'.format(
'---\n'.join(data['content'] for data in response.data) # '---\n'.join(data['content'] for data in response.data)
) + '\n' # ) + '\n'
model_response = qa( # model_response = qa(
{"question": additional_context + chat_message.question}) # {"question": additional_context + chat_message.question})
else: # else:
transformed_history = [] # transformed_history = []
for i in range(0, len(chat_message.history) - 1, 2): # for i in range(0, len(chat_message.history) - 1, 2):
user_message = chat_message.history[i][1] # user_message = chat_message.history[i][1]
assistant_message = chat_message.history[i + 1][1] # assistant_message = chat_message.history[i + 1][1]
transformed_history.append((user_message, assistant_message)) # transformed_history.append((user_message, assistant_message))
model_response = qa({"question": chat_message.question, "chat_history": transformed_history}) # model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
answer = model_response['answer'] # answer = model_response['answer']
if "source_documents" in answer: # if "source_documents" in answer:
sources = [ # sources = [
doc.metadata["file_name"] for doc in answer["source_documents"] # doc.metadata["file_name"] for doc in answer["source_documents"]
if "file_name" in doc.metadata] # if "file_name" in doc.metadata]
if sources: # if sources:
files = dict.fromkeys(sources) # files = dict.fromkeys(sources)
answer = answer + "\n\nRef: " + "; ".join(files) # answer = answer + "\n\nRef: " + "; ".join(files)
return answer return answer