diff --git a/backend/llm/qa.py b/backend/llm/qa.py index 12f9aafb9..2a05329c8 100644 --- a/backend/llm/qa.py +++ b/backend/llm/qa.py @@ -17,77 +17,59 @@ from langchain.vectorstores import SupabaseVectorStore from llm.prompt import LANGUAGE_PROMPT from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT from models.chats import ChatMessage +from pydantic import BaseModel, BaseSettings from supabase import Client, create_client from vectorstore.supabase import CustomSupabaseVectorStore +class BrainSettings(BaseSettings): + openai_api_key: str + anthropic_api_key: str + supabase_url: str + supabase_service_key: str + + class AnswerConversationBufferMemory(ConversationBufferMemory): """ref https://github.com/hwchase17/langchain/issues/5630#issuecomment-1574222564""" def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: return super(AnswerConversationBufferMemory, self).save_context( inputs, {'response': outputs['answer']}) - -def get_environment_variables(): - '''Get the environment variables.''' - openai_api_key = os.getenv("OPENAI_API_KEY") - anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") - supabase_url = os.getenv("SUPABASE_URL") - supabase_key = os.getenv("SUPABASE_SERVICE_KEY") - - return openai_api_key, anthropic_api_key, supabase_url, supabase_key - -def create_clients_and_embeddings(openai_api_key, supabase_url, supabase_key): - '''Create the clients and embeddings.''' - embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key) - supabase_client = create_client(supabase_url, supabase_key) - - return supabase_client, embeddings - def get_chat_history(inputs) -> str: res = [] for human, ai in inputs: res.append(f"{human}:{ai}\n") return "\n".join(res) -def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str, with_sources: bool = False): - '''Get the question answering language model.''' - openai_api_key, anthropic_api_key, supabase_url, supabase_key = get_environment_variables() - - '''User can override the openai_api_key''' - if user_openai_api_key is not None and user_openai_api_key != "": - openai_api_key = user_openai_api_key - - supabase_client, embeddings = create_clients_and_embeddings(openai_api_key, supabase_url, supabase_key) +class BrainPicking(BaseModel): + """ Class that allows the user to pick a brain. """ + llm_name: str = "gpt-3.5-turbo" + settings = BrainSettings() + embeddings: OpenAIEmbeddings = None + supabase_client: Client = None + vector_store: CustomSupabaseVectorStore = None + llm: ChatOpenAI = None + question_generator: LLMChain = None + doc_chain: ConversationalRetrievalChain = None - vector_store = CustomSupabaseVectorStore( - supabase_client, embeddings, table_name="vectors", user_id=user_id) - + class Config: + arbitrary_types_allowed = True + def init(self, model: str, user_id: str) -> "BrainPicking": + self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key) + self.supabase_client = create_client(self.settings.supabase_url, self.settings.supabase_service_key) + self.vector_store = CustomSupabaseVectorStore( + self.supabase_client, self.embeddings, table_name="vectors", user_id=user_id) + self.llm = ChatOpenAI(temperature=0, model_name=model) + self.question_generator = LLMChain(llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT) + self.doc_chain = load_qa_chain(self.llm, chain_type="stuff") + return self - qa = None - - if chat_message.model.startswith("gpt"): - llm = ChatOpenAI(temperature=0, model_name=chat_message.model) - question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT) - doc_chain = load_qa_chain(llm, chain_type="stuff") - + def get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain: + if user_openai_api_key is not None and user_openai_api_key != "": + self.settings.openai_api_key = user_openai_api_key qa = ConversationalRetrievalChain( - retriever=vector_store.as_retriever(), - max_tokens_limit=chat_message.max_tokens, question_generator=question_generator, - combine_docs_chain=doc_chain, get_chat_history=get_chat_history) - elif chat_message.model.startswith("vertex"): - qa = ConversationalRetrievalChain.from_llm( - ChatVertexAI(), vector_store.as_retriever(), verbose=True, - return_source_documents=with_sources, max_tokens_limit=1024,question_generator=question_generator, - combine_docs_chain=doc_chain) - elif anthropic_api_key and chat_message.model.startswith("claude"): - qa = ConversationalRetrievalChain.from_llm( - ChatAnthropic( - model=chat_message.model, anthropic_api_key=anthropic_api_key, temperature=chat_message.temperature, max_tokens_to_sample=chat_message.max_tokens), - vector_store.as_retriever(), verbose=False, - return_source_documents=with_sources, - max_tokens_limit=102400) - qa.combine_docs_chain = load_qa_chain(ChatAnthropic(), chain_type="stuff", prompt=LANGUAGE_PROMPT.QA_PROMPT) - - return qa + retriever=self.vector_store.as_retriever(), + max_tokens_limit=chat_message.max_tokens, question_generator=self.question_generator, + combine_docs_chain=self.doc_chain, get_chat_history=get_chat_history) + return qa diff --git a/backend/utils/vectors.py b/backend/utils/vectors.py index 5fb86891e..6064554b9 100644 --- a/backend/utils/vectors.py +++ b/backend/utils/vectors.py @@ -1,6 +1,6 @@ from langchain.embeddings.openai import OpenAIEmbeddings from langchain.schema import Document -from llm.qa import get_qa_llm +from llm.qa import BrainPicking from llm.summarization import llm_evaluate_summaries, llm_summerize from logger import get_logger from models.chats import ChatMessage @@ -49,7 +49,9 @@ def similarity_search(commons: CommonsDep, query, table='match_summaries', top_k return summaries.data def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user_openai_api_key:str): - qa = get_qa_llm(chat_message, email, user_openai_api_key) + + Brain = BrainPicking().init(chat_message.model, email) + qa = Brain.get_qa(chat_message, user_openai_api_key) if chat_message.use_summarization: # 1. get summaries from the vector store based on question