2023-05-22 09:39:55 +03:00
|
|
|
import os
|
2023-06-10 11:43:44 +03:00
|
|
|
from typing import Any, Dict, List
|
2023-05-31 14:51:23 +03:00
|
|
|
|
2023-06-17 02:16:11 +03:00
|
|
|
from langchain.chains import ConversationalRetrievalChain, LLMChain
|
2023-06-13 11:35:06 +03:00
|
|
|
from langchain.chains.question_answering import load_qa_chain
|
2023-06-17 02:16:11 +03:00
|
|
|
from langchain.chains.router.llm_router import (LLMRouterChain,
|
|
|
|
RouterOutputParser)
|
|
|
|
from langchain.chains.router.multi_prompt_prompt import \
|
|
|
|
MULTI_PROMPT_ROUTER_TEMPLATE
|
2023-06-01 17:01:27 +03:00
|
|
|
from langchain.chat_models import ChatOpenAI, ChatVertexAI
|
2023-05-22 09:39:55 +03:00
|
|
|
from langchain.chat_models.anthropic import ChatAnthropic
|
2023-05-31 14:51:23 +03:00
|
|
|
from langchain.docstore.document import Document
|
|
|
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
2023-06-13 11:35:06 +03:00
|
|
|
from langchain.llms import OpenAI, VertexAI
|
2023-05-31 14:51:23 +03:00
|
|
|
from langchain.memory import ConversationBufferMemory
|
|
|
|
from langchain.vectorstores import SupabaseVectorStore
|
2023-06-17 02:16:11 +03:00
|
|
|
from llm.prompt import LANGUAGE_PROMPT
|
|
|
|
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
|
2023-06-04 00:12:42 +03:00
|
|
|
from models.chats import ChatMessage
|
2023-05-31 14:51:23 +03:00
|
|
|
from supabase import Client, create_client
|
2023-06-19 21:15:34 +03:00
|
|
|
from vectorstore.supabase import CustomSupabaseVectorStore
|
2023-05-22 09:39:55 +03:00
|
|
|
|
2023-06-10 11:43:44 +03:00
|
|
|
|
|
|
|
class AnswerConversationBufferMemory(ConversationBufferMemory):
    """Buffer memory that accepts the chain output under the 'answer' key.

    ref https://github.com/hwchase17/langchain/issues/5630#issuecomment-1574222564
    """

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        # ConversationalRetrievalChain emits its result as 'answer', while the
        # base ConversationBufferMemory expects 'response' — remap before saving.
        remapped_outputs = {'response': outputs['answer']}
        return super().save_context(inputs, remapped_outputs)
|
|
|
|
|
|
|
|
|
2023-05-31 14:51:23 +03:00
|
|
|
def get_environment_variables():
    """Read the API keys and Supabase credentials from the environment.

    Returns:
        tuple: ``(openai_api_key, anthropic_api_key, supabase_url,
        supabase_key)`` — each element is ``None`` when the corresponding
        environment variable is unset.
    """
    return (
        os.getenv("OPENAI_API_KEY"),
        os.getenv("ANTHROPIC_API_KEY"),
        os.getenv("SUPABASE_URL"),
        os.getenv("SUPABASE_SERVICE_KEY"),
    )
|
|
|
|
|
|
|
|
def create_clients_and_embeddings(openai_api_key, supabase_url, supabase_key):
    """Build the Supabase client and the OpenAI embeddings wrapper.

    Args:
        openai_api_key: key used to authenticate embedding requests.
        supabase_url: URL of the Supabase project.
        supabase_key: service key for the Supabase project.

    Returns:
        tuple: ``(supabase_client, embeddings)``.
    """
    # The two constructions are independent of each other.
    supabase_client = create_client(supabase_url, supabase_key)
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    return supabase_client, embeddings
|
|
|
|
|
2023-06-17 02:16:11 +03:00
|
|
|
def get_chat_history(inputs) -> str:
    """Render ``(human, ai)`` message pairs as a plain-text transcript.

    Each pair becomes the string ``"<human>:<ai>\n"``; the entries are then
    joined with ``"\n"``, so consecutive exchanges are separated by a blank
    line and the last entry keeps its trailing newline.
    """
    return "\n".join(f"{human}:{ai}\n" for human, ai in inputs)
|
|
|
|
|
2023-06-15 16:25:12 +03:00
|
|
|
def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str, with_sources: bool = False):
    """Build a ConversationalRetrievalChain for the requested model.

    Args:
        chat_message: carries the model name, temperature and max_tokens.
        user_id: scopes the vector-store retrieval to this user's documents.
        user_openai_api_key: optional per-user key; when non-empty it
            overrides the OPENAI_API_KEY environment variable.
        with_sources: when True, chains built via ``from_llm`` also return
            the source documents alongside the answer.

    Returns:
        A ConversationalRetrievalChain for models starting with "gpt",
        "vertex" or "claude" (claude additionally requires
        ANTHROPIC_API_KEY to be set), or ``None`` for any other model.
    """
    openai_api_key, anthropic_api_key, supabase_url, supabase_key = get_environment_variables()

    # The user-supplied key (if any) takes precedence over the environment.
    if user_openai_api_key is not None and user_openai_api_key != "":
        openai_api_key = user_openai_api_key

    supabase_client, embeddings = create_clients_and_embeddings(
        openai_api_key, supabase_url, supabase_key)

    # Retrieval is restricted to vectors belonging to this user.
    vector_store = CustomSupabaseVectorStore(
        supabase_client, embeddings, table_name="vectors", user_id=user_id)

    qa = None
    if chat_message.model.startswith("gpt"):
        llm = ChatOpenAI(temperature=0, model_name=chat_message.model)
        question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
        doc_chain = load_qa_chain(llm, chain_type="stuff")

        qa = ConversationalRetrievalChain(
            retriever=vector_store.as_retriever(),
            max_tokens_limit=chat_message.max_tokens, question_generator=question_generator,
            combine_docs_chain=doc_chain, get_chat_history=get_chat_history)
    elif chat_message.model.startswith("vertex"):
        # BUG FIX: this branch previously forwarded question_generator= and
        # combine_docs_chain= — names bound only inside the "gpt" branch —
        # so every vertex request raised NameError. from_llm builds both
        # chains internally, so the kwargs are simply dropped.
        qa = ConversationalRetrievalChain.from_llm(
            ChatVertexAI(), vector_store.as_retriever(), verbose=True,
            return_source_documents=with_sources, max_tokens_limit=1024)
    elif anthropic_api_key and chat_message.model.startswith("claude"):
        qa = ConversationalRetrievalChain.from_llm(
            ChatAnthropic(
                model=chat_message.model, anthropic_api_key=anthropic_api_key, temperature=chat_message.temperature, max_tokens_to_sample=chat_message.max_tokens),
            vector_store.as_retriever(), verbose=False,
            return_source_documents=with_sources,
            max_tokens_limit=102400)
        # Replace the default combine-docs chain with one using the
        # language-aware QA prompt; kept inside this branch so it only runs
        # when qa was actually built (avoids AttributeError on qa=None).
        qa.combine_docs_chain = load_qa_chain(
            ChatAnthropic(), chain_type="stuff", prompt=LANGUAGE_PROMPT.QA_PROMPT)
    return qa
|