import os # A module to interact with the OS
from typing import Any, Dict, List # For type hinting

# Importing modules and classes from the 'langchain' library,
# used here to build the conversational retrieval (question-answering) pipeline
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.combine_documents.base import \
    BaseCombineDocumentsChain  # Actual type returned by load_qa_chain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.router.llm_router import (LLMRouterChain,
                                                RouterOutputParser)
from langchain.chains.router.multi_prompt_prompt import \
    MULTI_PROMPT_ROUTER_TEMPLATE
from langchain.chat_models import ChatOpenAI, ChatVertexAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI, VertexAI
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores import SupabaseVectorStore
from llm.prompt import LANGUAGE_PROMPT
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
from models.chats import \
    ChatMessage  # Custom ChatMessage class for handling chat messages
from models.settings import \
    BrainSettings  # Settings related to the 'brain'
from pydantic import (BaseModel,  # For data validation and settings management
                      BaseSettings)
from supabase import (Client,  # For interacting with the Supabase database
                      create_client)
from vectorstore.supabase import \
    CustomSupabaseVectorStore  # Custom class for handling vector storage with Supabase


class AnswerConversationBufferMemory(ConversationBufferMemory):
    """
    A specialized version of ConversationBufferMemory.
    It overrides save_context to read the response from the 'answer' key of
    the outputs, since ConversationalRetrievalChain returns its result under
    'answer' while the parent class expects a 'response' key.
    """

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        # Override the parent implementation: remap the chain's 'answer' output
        # to the 'response' key the parent class expects, then delegate to it
        return super(AnswerConversationBufferMemory, self).save_context(
            inputs, {'response': outputs['answer']})
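
# Illustration: with this memory, saving a ConversationalRetrievalChain result
#   memory.save_context({"question": "Hi"}, {"answer": "Hello!"})
# is delegated to the parent class as
#   ConversationBufferMemory.save_context(memory, {"question": "Hi"}, {"response": "Hello!"})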


def get_chat_history(inputs) -> str:
    """
    Concatenate the chat history into a single string.

    :param inputs: List of tuples containing human and AI messages.
    :return: Concatenated string of the chat history.
    """
    res = []
    for human, ai in inputs:
        res.append(f"{human}:{ai}\n")
    return "\n".join(res)
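
# Illustration: get_chat_history([("Hi", "Hello!"), ("How are you?", "Fine.")])
# returns "Hi:Hello!\n\nHow are you?:Fine.\n" (each pair is rendered as
# "human:ai\n", and the entries are then joined with another newline).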


class BrainPicking(BaseModel):
    """
    Main class for the Brain Picking functionality.
    It initializes a chat model, generates standalone questions, and retrieves
    answers using a ConversationalRetrievalChain.
    """

    # Default class attributes
    llm_name: str = "gpt-3.5-turbo"
    settings = BrainSettings()
    embeddings: OpenAIEmbeddings = None
    supabase_client: Client = None
    vector_store: CustomSupabaseVectorStore = None
    llm: ChatOpenAI = None
    question_generator: LLMChain = None
    doc_chain: BaseCombineDocumentsChain = None  # Set by load_qa_chain in init()

    class Config:
        # Allow arbitrary (non-pydantic) types such as the Supabase client,
        # vector store, and langchain chains to be used as field values
        arbitrary_types_allowed = True

    def init(self, model: str, user_id: str) -> "BrainPicking":
        """
        Initialize the BrainPicking instance by setting up the embeddings,
        Supabase client, vector store, language model, and chains.

        :param model: Language model name to be used.
        :param user_id: The user id to be used for CustomSupabaseVectorStore.
        :return: The initialized BrainPicking instance.
        """
        # Embeddings and a Supabase-backed vector store scoped to this user
        self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key)
        self.supabase_client = create_client(self.settings.supabase_url, self.settings.supabase_service_key)
        self.vector_store = CustomSupabaseVectorStore(
            self.supabase_client, self.embeddings, table_name="vectors", user_id=user_id)
        # Chat model plus the two sub-chains used by ConversationalRetrievalChain:
        # one condenses the follow-up question, the other answers over the documents
        self.llm = ChatOpenAI(temperature=0, model_name=model)
        self.question_generator = LLMChain(llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT)
        self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
        return self
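
    # Illustration (hypothetical user id): init() returns self, so setup can
    # be chained in one expression:
    #   brain = BrainPicking().init(model="gpt-3.5-turbo", user_id="user-123")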
    def _get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain:
        """
        Build a ConversationalRetrievalChain for the given chat message and API key.

        :param chat_message: The chat message containing the history.
        :param user_openai_api_key: The OpenAI API key to be used, if any.
        :return: ConversationalRetrievalChain instance
        """
        # If the user provided an API key, prefer it over the default settings.
        # (Note: components already built in init() keep the key they were
        # created with; the updated key only affects anything created later.)
        if user_openai_api_key is not None and user_openai_api_key != "":
            self.settings.openai_api_key = user_openai_api_key

        # Initialize and return a ConversationalRetrievalChain
        qa = ConversationalRetrievalChain(
            retriever=self.vector_store.as_retriever(),
            max_tokens_limit=chat_message.max_tokens,
            question_generator=self.question_generator,
            combine_docs_chain=self.doc_chain,
            get_chat_history=get_chat_history)
        return qa
    def generate_answer(self, chat_message: ChatMessage, user_openai_api_key) -> str:
        """
        Generate an answer to a given chat message by interacting with the language model.

        :param chat_message: The chat message containing the history.
        :param user_openai_api_key: The OpenAI API key to be used, if any.
        :return: The generated answer.
        """
        transformed_history = []

        # Get the QA chain
        qa = self._get_qa(chat_message, user_openai_api_key)

        # Transform the chat history into a list of (user, assistant) tuples,
        # pairing each user turn with the assistant turn that follows it
        for i in range(0, len(chat_message.history) - 1, 2):
            user_message = chat_message.history[i][1]
            assistant_message = chat_message.history[i + 1][1]
            transformed_history.append((user_message, assistant_message))

        # Generate the model response using the QA chain
        model_response = qa({"question": chat_message.question, "chat_history": transformed_history})
        answer = model_response['answer']

        return answer
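

# Minimal usage sketch (illustrative; assumes BrainSettings reads the Supabase
# and OpenAI credentials from the environment, and that `incoming` is a
# ChatMessage whose `history` holds (role, text) pairs and which carries
# `question` and `max_tokens`):
#
#   brain = BrainPicking().init(model="gpt-3.5-turbo", user_id="user-123")
#   reply = brain.generate_answer(incoming, user_openai_api_key=None)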