quivr/backend/llm/base.py

from abc import abstractmethod
from typing import AsyncIterable, List, Optional

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.llms.base import LLM
from logger import get_logger
from models.settings import BrainSettings  # Importing settings related to the 'brain'
from pydantic import BaseModel  # For data validation and settings management
from utils.constants import streaming_compatible_models

logger = get_logger(__name__)

class BaseBrainPicking(BaseModel):
    """
    Base class for BrainPicking. Allows you to interact with LLMs (large language models).
    Use this class to define the abstract methods, plus the methods and properties
    common to all child classes.
    """

    # Instantiate settings
    brain_settings = BrainSettings()

    # Default class attributes
    model: Optional[str] = None
    temperature: float = 0.0
    chat_id: Optional[str] = None
    brain_id: Optional[str] = None
    max_tokens: int = 256
    user_openai_api_key: Optional[str] = None
    streaming: bool = False
    openai_api_key: Optional[str] = None
    callbacks: Optional[List[AsyncCallbackHandler]] = None

    def _determine_api_key(self, openai_api_key, user_openai_api_key):
        """If the user provided an API key, use it; otherwise fall back to the default."""
        if user_openai_api_key is not None:
            return user_openai_api_key
        return openai_api_key

    def _determine_streaming(self, model: str, streaming: bool) -> bool:
        """If the model name allows for streaming and streaming is declared, set streaming to True."""
        if streaming and model not in streaming_compatible_models:
            logger.warning(
                f"Streaming is not compatible with {model}. Streaming will be set to False."
            )
            return False
        return streaming

    def _determine_callback_array(
        self, streaming
    ) -> List[AsyncIteratorCallbackHandler]:
        """If streaming is set, use an AsyncIteratorCallbackHandler as the only callback."""
        if streaming:
            # An instantiated handler is required so tokens can be iterated as they stream in.
            return [AsyncIteratorCallbackHandler()]
        return []

    def __init__(self, **data):
        super().__init__(**data)
        self.openai_api_key = self._determine_api_key(
            self.brain_settings.openai_api_key, self.user_openai_api_key
        )
        self.streaming = self._determine_streaming(self.model, self.streaming)
        self.callbacks = self._determine_callback_array(self.streaming)

    class Config:
        """Configuration of the Pydantic object."""

        # Allow arbitrary types for class validation
        arbitrary_types_allowed = True

    # The methods below define the names, arguments, and return types of the
    # functions most useful to child classes. They should be overridden
    # wherever they are used.

    @abstractmethod
    def _create_llm(self, model, streaming=False, callbacks=None) -> LLM:
        """
        Determine and construct the language model.
        :param model: Language model name to be used.
        :return: Language model instance
        This method should take into account the following:
        - Whether the model is streaming compatible
        - Whether the model is private
        - Whether the model should use an OpenAI API key (see the _determine_api_key method)
        """

    @abstractmethod
    def _create_question_chain(self, model) -> LLMChain:
        """
        Determine and construct the question chain.
        :param model: Language model name to be used.
        :return: Question chain instance
        This method should take into account the following:
        - Which prompt to use (normally CONDENSE_QUESTION_PROMPT)
        """

    @abstractmethod
    def _create_doc_chain(self, model) -> LLMChain:
        """
        Determine and construct the document chain.
        :param model: Language model name to be used.
        :return: Document chain instance
        This method should take into account the following:
        - chain_type (normally "stuff")
        - Whether the model is streaming compatible and/or streaming is set (see _determine_streaming).
        """

    @abstractmethod
    def _create_qa(
        self, question_chain, document_chain
    ) -> ConversationalRetrievalChain:
        """
        Construct a conversational retrieval chain.
        :param question_chain: The question chain
        :param document_chain: The document chain
        :return: ConversationalRetrievalChain instance
        """

    @abstractmethod
    def _call_chain(self, chain, question, history) -> str:
        """
        Call a chain with a given question and history.
        :param chain: The chain, e.g. QA (ConversationalRetrievalChain)
        :param question: The user prompt
        :param history: The chat history from the DB
        :return: The answer.
        """

    async def _acall_chain(self, chain, question, history) -> str:
        """
        Asynchronously call a chain with a given question and history.
        :param chain: The chain, e.g. QA (ConversationalRetrievalChain)
        :param question: The user prompt
        :param history: The chat history from the DB
        :return: The answer.
        """
        raise NotImplementedError(
            "Async generation not implemented for this BrainPicking Class."
        )

    @abstractmethod
    def generate_answer(self, question: str) -> str:
        """
        Generate an answer to a given question using the QA chain.
        :param question: The question
        :return: The generated answer.
        This method should also call _create_qa, get_chat_history and
        format_chat_history, and update the chat history in the DB.
        """

    async def generate_stream(self, question: str) -> AsyncIterable:
        """
        Generate a streaming answer to a given question using the QA chain.
        :param question: The question
        :return: An async iterable which generates the answer.
        In addition, this method should:
        - Update the chat history in the DB with the chat details (chat_id, question) -> return a message_id and timestamp
        - Run _acall_chain concurrently via asyncio.create_task
        - Append each token to the chat_history object from the DB and yield it from the function
        - Append each token from the callback to an answer string -> used to update the chat history in the DB (update_message_by_id)
        """
        raise NotImplementedError(
            "Async generation not implemented for this BrainPicking Class."
        )
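        # A hypothetical sketch of the streaming pattern described above,
        # assuming self.callbacks[0] is an AsyncIteratorCallbackHandler and
        # that history, message_id and update_message_by_id come from the DB
        # helpers named in the docstring (none are defined in this module):
        #
        #   import asyncio
        #   callback = self.callbacks[0]
        #   qa = self._create_qa(
        #       self._create_question_chain(self.model),
        #       self._create_doc_chain(self.model),
        #   )
        #   task = asyncio.create_task(self._acall_chain(qa, question, history))
        #   tokens = []
        #   async for token in callback.aiter():
        #       tokens.append(token)
        #       yield token
        #   await task
        #   update_message_by_id(message_id, "".join(tokens))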