mirror of
https://github.com/StanGirard/quivr.git
synced 2024-11-28 05:13:57 +03:00
8fbb4b2d91
* fix: gpt4all * fix: pyright * Update backend/llm/openai.py * fix: remove backend tag * fix: typing * feat: qa_base class * fix: pyright * fix: model_path not found
229 lines
7.6 KiB
Python
229 lines
7.6 KiB
Python
import asyncio
|
|
import json
|
|
from abc import abstractmethod, abstractproperty
|
|
from typing import AsyncIterable, Awaitable
|
|
|
|
from langchain.chains import ConversationalRetrievalChain, LLMChain
|
|
from langchain.chains.question_answering import load_qa_chain
|
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
|
from langchain.llms.base import BaseLLM
|
|
from logger import get_logger
|
|
from models.chat import ChatHistory
|
|
from repository.chat.format_chat_history import format_chat_history
|
|
from repository.chat.get_chat_history import get_chat_history
|
|
from repository.chat.update_chat_history import update_chat_history
|
|
from repository.chat.update_message_by_id import update_message_by_id
|
|
from supabase.client import Client, create_client
|
|
from vectorstore.supabase import CustomSupabaseVectorStore
|
|
|
|
from .base import BaseBrainPicking
|
|
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class QABaseBrainPicking(BaseBrainPicking):
|
|
"""
|
|
Base class for the Brain Picking functionality using the Conversational Retrieval Chain (QA) from Langchain.
|
|
It is not designed to be used directly, but to be subclassed by other classes which use the QA chain.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model: str,
|
|
brain_id: str,
|
|
chat_id: str,
|
|
streaming: bool = False,
|
|
**kwargs,
|
|
) -> "QABaseBrainPicking": # pyright: ignore reportPrivateUsage=none
|
|
"""
|
|
Initialize the QA BrainPicking class by setting embeddings, supabase client, vector store, language model and chains.
|
|
:return: QABrainPicking instance
|
|
"""
|
|
super().__init__(
|
|
model=model,
|
|
brain_id=brain_id,
|
|
chat_id=chat_id,
|
|
streaming=streaming,
|
|
**kwargs,
|
|
)
|
|
|
|
@abstractproperty
|
|
def embeddings(self) -> OpenAIEmbeddings:
|
|
raise NotImplementedError("This property should be overridden in a subclass.")
|
|
|
|
@property
|
|
def supabase_client(self) -> Client:
|
|
return create_client(
|
|
self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
|
|
)
|
|
|
|
@property
|
|
def vector_store(self) -> CustomSupabaseVectorStore:
|
|
return CustomSupabaseVectorStore(
|
|
self.supabase_client,
|
|
self.embeddings,
|
|
table_name="vectors",
|
|
brain_id=self.brain_id,
|
|
)
|
|
|
|
@property
|
|
def question_llm(self):
|
|
return self._create_llm(model=self.model, streaming=False)
|
|
|
|
@property
|
|
def doc_llm(self):
|
|
return self._create_llm(
|
|
model=self.model, streaming=self.streaming, callbacks=self.callbacks
|
|
)
|
|
|
|
@property
|
|
def question_generator(self) -> LLMChain:
|
|
return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT)
|
|
|
|
@property
|
|
def doc_chain(self) -> LLMChain:
|
|
return load_qa_chain(
|
|
llm=self.doc_llm, chain_type="stuff"
|
|
) # pyright: ignore reportPrivateUsage=none
|
|
|
|
@property
|
|
def qa(self) -> ConversationalRetrievalChain:
|
|
return ConversationalRetrievalChain(
|
|
retriever=self.vector_store.as_retriever(),
|
|
question_generator=self.question_generator,
|
|
combine_docs_chain=self.doc_chain, # pyright: ignore reportPrivateUsage=none
|
|
verbose=True,
|
|
)
|
|
|
|
@abstractmethod
|
|
def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
|
|
"""
|
|
Determine the language model to be used.
|
|
:param model: Language model name to be used.
|
|
:param streaming: Whether to enable streaming of the model
|
|
:param callbacks: Callbacks to be used for streaming
|
|
:return: Language model instance
|
|
"""
|
|
|
|
def _call_chain(self, chain, question, history):
|
|
"""
|
|
Call a chain with a given question and history.
|
|
:param chain: The chain eg QA (ConversationalRetrievalChain)
|
|
:param question: The user prompt
|
|
:param history: The chat history from DB
|
|
:return: The answer.
|
|
"""
|
|
return chain(
|
|
{
|
|
"question": question,
|
|
"chat_history": history,
|
|
}
|
|
)
|
|
|
|
def generate_answer(self, question: str) -> ChatHistory:
|
|
"""
|
|
Generate an answer to a given question by interacting with the language model.
|
|
:param question: The question
|
|
:return: The generated answer.
|
|
"""
|
|
transformed_history = []
|
|
|
|
# Get the history from the database
|
|
history = get_chat_history(self.chat_id)
|
|
|
|
# Format the chat history into a list of tuples (human, ai)
|
|
transformed_history = format_chat_history(history)
|
|
|
|
# Generate the model response using the QA chain
|
|
model_response = self._call_chain(self.qa, question, transformed_history)
|
|
|
|
answer = model_response["answer"]
|
|
|
|
# Update chat history
|
|
chat_answer = update_chat_history(
|
|
chat_id=self.chat_id,
|
|
user_message=question,
|
|
assistant=answer,
|
|
)
|
|
|
|
return chat_answer
|
|
|
|
async def _acall_chain(self, chain, question, history):
|
|
"""
|
|
Call a chain with a given question and history.
|
|
:param chain: The chain eg QA (ConversationalRetrievalChain)
|
|
:param question: The user prompt
|
|
:param history: The chat history from DB
|
|
:return: The answer.
|
|
"""
|
|
return chain.acall(
|
|
{
|
|
"question": question,
|
|
"chat_history": history,
|
|
}
|
|
)
|
|
|
|
async def generate_stream(self, question: str) -> AsyncIterable:
|
|
"""
|
|
Generate a streaming answer to a given question by interacting with the language model.
|
|
:param question: The question
|
|
:return: An async iterable which generates the answer.
|
|
"""
|
|
|
|
history = get_chat_history(self.chat_id)
|
|
callback = self.callbacks[0]
|
|
|
|
transformed_history = []
|
|
|
|
# Format the chat history into a list of tuples (human, ai)
|
|
transformed_history = format_chat_history(history)
|
|
|
|
# Initialize a list to hold the tokens
|
|
response_tokens = []
|
|
|
|
# Wrap an awaitable with a event to signal when it's done or an exception is raised.
|
|
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
|
try:
|
|
await fn
|
|
except Exception as e:
|
|
logger.error(f"Caught exception: {e}")
|
|
finally:
|
|
event.set()
|
|
|
|
task = asyncio.create_task(
|
|
wrap_done(
|
|
self.qa._acall_chain( # pyright: ignore reportPrivateUsage=none
|
|
self.qa, question, transformed_history
|
|
),
|
|
callback.done, # pyright: ignore reportPrivateUsage=none
|
|
)
|
|
)
|
|
|
|
streamed_chat_history = update_chat_history(
|
|
chat_id=self.chat_id,
|
|
user_message=question,
|
|
assistant="",
|
|
)
|
|
|
|
# Use the aiter method of the callback to stream the response with server-sent-events
|
|
async for token in callback.aiter(): # pyright: ignore reportPrivateUsage=none
|
|
logger.info("Token: %s", token)
|
|
|
|
# Add the token to the response_tokens list
|
|
response_tokens.append(token)
|
|
streamed_chat_history.assistant = token
|
|
|
|
yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
|
|
|
|
await task
|
|
|
|
# Join the tokens to create the assistant's response
|
|
assistant = "".join(response_tokens)
|
|
|
|
update_message_by_id(
|
|
message_id=streamed_chat_history.message_id,
|
|
user_message=question,
|
|
assistant=assistant,
|
|
)
|