Mirror of https://github.com/StanGirard/quivr.git, synced 2024-12-24 20:03:41 +03:00
feat(brainpicking): new class
This commit is contained in:
parent 17aaf18d61, commit d42f14f431
@@ -17,77 +17,59 @@ from langchain.vectorstores import SupabaseVectorStore
 from llm.prompt import LANGUAGE_PROMPT
 from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
 from models.chats import ChatMessage
+from pydantic import BaseModel, BaseSettings
 from supabase import Client, create_client
 from vectorstore.supabase import CustomSupabaseVectorStore


+class BrainSettings(BaseSettings):
+    openai_api_key: str
+    anthropic_api_key: str
+    supabase_url: str
+    supabase_service_key: str
+
+
 class AnswerConversationBufferMemory(ConversationBufferMemory):
     """ref https://github.com/hwchase17/langchain/issues/5630#issuecomment-1574222564"""

     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
         return super(AnswerConversationBufferMemory, self).save_context(
             inputs, {'response': outputs['answer']})


-def get_environment_variables():
-    '''Get the environment variables.'''
-    openai_api_key = os.getenv("OPENAI_API_KEY")
-    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
-    supabase_url = os.getenv("SUPABASE_URL")
-    supabase_key = os.getenv("SUPABASE_SERVICE_KEY")
-
-    return openai_api_key, anthropic_api_key, supabase_url, supabase_key
-
-
-def create_clients_and_embeddings(openai_api_key, supabase_url, supabase_key):
-    '''Create the clients and embeddings.'''
-    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
-    supabase_client = create_client(supabase_url, supabase_key)
-
-    return supabase_client, embeddings
-
-
 def get_chat_history(inputs) -> str:
     res = []
     for human, ai in inputs:
         res.append(f"{human}:{ai}\n")
     return "\n".join(res)


-def get_qa_llm(chat_message: ChatMessage, user_id: str, user_openai_api_key: str, with_sources: bool = False):
-    '''Get the question answering language model.'''
-    openai_api_key, anthropic_api_key, supabase_url, supabase_key = get_environment_variables()
+class BrainPicking(BaseModel):
+    """ Class that allows the user to pick a brain. """
+    llm_name: str = "gpt-3.5-turbo"
+    settings = BrainSettings()
+    embeddings: OpenAIEmbeddings = None
+    supabase_client: Client = None
+    vector_store: CustomSupabaseVectorStore = None
+    llm: ChatOpenAI = None
+    question_generator: LLMChain = None
+    doc_chain: ConversationalRetrievalChain = None

-    '''User can override the openai_api_key'''
+    class Config:
+        arbitrary_types_allowed = True
+
+    def init(self, model: str, user_id: str) -> "BrainPicking":
+        self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key)
+        self.supabase_client = create_client(self.settings.supabase_url, self.settings.supabase_service_key)
+        self.vector_store = CustomSupabaseVectorStore(
+            self.supabase_client, self.embeddings, table_name="vectors", user_id=user_id)
+        self.llm = ChatOpenAI(temperature=0, model_name=model)
+        self.question_generator = LLMChain(llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT)
+        self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
+        return self
+
+    def get_qa(self, chat_message: ChatMessage, user_openai_api_key) -> ConversationalRetrievalChain:
         if user_openai_api_key is not None and user_openai_api_key != "":
-            openai_api_key = user_openai_api_key
-
-    supabase_client, embeddings = create_clients_and_embeddings(openai_api_key, supabase_url, supabase_key)
-
-    vector_store = CustomSupabaseVectorStore(
-        supabase_client, embeddings, table_name="vectors", user_id=user_id)
-
-    qa = None
-    if chat_message.model.startswith("gpt"):
-        llm = ChatOpenAI(temperature=0, model_name=chat_message.model)
-        question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
-        doc_chain = load_qa_chain(llm, chain_type="stuff")
+            self.settings.openai_api_key = user_openai_api_key
         qa = ConversationalRetrievalChain(
-            retriever=vector_store.as_retriever(),
-            max_tokens_limit=chat_message.max_tokens, question_generator=question_generator,
-            combine_docs_chain=doc_chain, get_chat_history=get_chat_history)
-    elif chat_message.model.startswith("vertex"):
-        qa = ConversationalRetrievalChain.from_llm(
-            ChatVertexAI(), vector_store.as_retriever(), verbose=True,
-            return_source_documents=with_sources, max_tokens_limit=1024, question_generator=question_generator,
-            combine_docs_chain=doc_chain)
-    elif anthropic_api_key and chat_message.model.startswith("claude"):
-        qa = ConversationalRetrievalChain.from_llm(
-            ChatAnthropic(
-                model=chat_message.model, anthropic_api_key=anthropic_api_key, temperature=chat_message.temperature, max_tokens_to_sample=chat_message.max_tokens),
-            vector_store.as_retriever(), verbose=False,
-            return_source_documents=with_sources,
-            max_tokens_limit=102400)
-        qa.combine_docs_chain = load_qa_chain(ChatAnthropic(), chain_type="stuff", prompt=LANGUAGE_PROMPT.QA_PROMPT)
+            retriever=self.vector_store.as_retriever(),
+            max_tokens_limit=chat_message.max_tokens, question_generator=self.question_generator,
+            combine_docs_chain=self.doc_chain, get_chat_history=get_chat_history)
         return qa
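Aside on the new settings class: pydantic's BaseSettings (the v1 API imported in this diff) fills each annotated field from the environment variable of the same name when the class is instantiated, which is what lets this commit drop the hand-rolled get_environment_variables() helper above. A minimal sketch of that behaviour, with placeholder values rather than anything from the repo:

import os

from pydantic import BaseSettings  # pydantic v1 API, matching the import in the diff


class BrainSettings(BaseSettings):
    openai_api_key: str
    supabase_url: str


# BaseSettings matches field names to environment variables case-insensitively.
os.environ["OPENAI_API_KEY"] = "sk-placeholder"
os.environ["SUPABASE_URL"] = "https://example.supabase.co"

settings = BrainSettings()
assert settings.openai_api_key == "sk-placeholder"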
@@ -1,6 +1,6 @@
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.schema import Document
-from llm.qa import get_qa_llm
+from llm.qa import BrainPicking
 from llm.summarization import llm_evaluate_summaries, llm_summerize
 from logger import get_logger
 from models.chats import ChatMessage
@@ -49,7 +49,9 @@ def similarity_search(commons: CommonsDep, query, table='match_summaries', top_k
     return summaries.data

 def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user_openai_api_key:str):
-    qa = get_qa_llm(chat_message, email, user_openai_api_key)
+
+    Brain = BrainPicking().init(chat_message.model, email)
+    qa = Brain.get_qa(chat_message, user_openai_api_key)

     if chat_message.use_summarization:
         # 1. get summaries from the vector store based on question
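Taken together, the call sequence in get_answer now looks roughly like the sketch below. This is illustrative only: the ChatMessage field values are invented, and the final call follows the input/output convention of langchain's ConversationalRetrievalChain from this era (a question plus chat_history in, a dict with an "answer" key out).

# Hypothetical driver, not part of the commit; it wires BrainPicking together
# the same way get_answer() does in the hunk above.
from llm.qa import BrainPicking
from models.chats import ChatMessage

chat_message = ChatMessage(model="gpt-3.5-turbo", question="What do my documents say?",
                           history=[], temperature=0.0, max_tokens=256)  # invented values

brain = BrainPicking().init(model=chat_message.model, user_id="user@example.com")
qa = brain.get_qa(chat_message, user_openai_api_key=None)

# The chain condenses the question with CONDENSE_QUESTION_PROMPT, retrieves from
# the user's vectors, and answers with the "stuff" QA chain built in init().
result = qa({"question": chat_message.question, "chat_history": []})
print(result["answer"])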