# quivr/backend/llm/qa_headless.py
# 2023-08-25 14:03:57 +02:00
#
# 234 lines
# 7.7 KiB
# Python

import asyncio
import json
from typing import AsyncIterable, Awaitable, List, Optional
from uuid import UUID
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from logger import get_logger
from models.chats import ChatQuestion
from models.databases.supabase.chats import CreateChatHistory
from models.prompt import Prompt
from pydantic import BaseModel
from repository.chat import (
GetChatHistoryOutput,
format_chat_history,
format_history_to_openai_mesages,
get_chat_history,
update_chat_history,
update_message_by_id,
)
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
# Module-level logger obtained from the project's get_logger helper,
# keyed by this module's name.
logger = get_logger(__name__)

# Fallback system prompt, used when no custom prompt resolves for the chat.
SYSTEM_MESSAGE = (
    "Your name is Quivr. You're a helpful assistant. "
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer."
)
class HeadlessQA(BaseModel):
    """Answer chat questions with a plain OpenAI chat model.

    The chat history is loaded from storage, combined with a system prompt
    (a custom prompt when one resolves, otherwise ``SYSTEM_MESSAGE``) and the
    new question, then sent to ``ChatOpenAI``. Every history row written by
    this class is stored with ``brain_id`` set to ``None``.
    """

    model: str
    temperature: float = 0.0
    # NOTE(review): max_tokens is declared here but is not forwarded to
    # ChatOpenAI by _create_llm -- confirm whether that is intended.
    max_tokens: int = 256
    user_openai_api_key: Optional[str] = None
    openai_api_key: Optional[str] = None
    streaming: bool = False
    chat_id: str
    callbacks: Optional[List[AsyncIteratorCallbackHandler]] = None
    prompt_id: Optional[UUID] = None

    def _determine_api_key(self, openai_api_key, user_openai_api_key):
        """Return the user-provided API key when set, else the default key."""
        if user_openai_api_key is not None:
            return user_openai_api_key
        return openai_api_key

    def _determine_streaming(self, streaming: bool) -> bool:
        """Return the requested streaming flag unchanged."""
        return streaming

    def _determine_callback_array(
        self, streaming
    ) -> List[AsyncIteratorCallbackHandler]:
        """Return a single AsyncIteratorCallbackHandler when streaming, else no callbacks."""
        if streaming:
            return [AsyncIteratorCallbackHandler()]
        return []

    def __init__(self, **data):
        """Validate the fields, then derive the effective API key and callbacks."""
        super().__init__(**data)
        # Use the module logger rather than printing to stdout.
        logger.debug("Initializing HeadlessQA")
        self.openai_api_key = self._determine_api_key(
            self.openai_api_key, self.user_openai_api_key
        )
        self.streaming = self._determine_streaming(self.streaming)
        self.callbacks = self._determine_callback_array(self.streaming)

    @property
    def prompt_to_use(self) -> Optional[Prompt]:
        """Prompt resolved by ``get_prompt_to_use`` for this ``prompt_id``.

        The helper is invoked on every access, so callers should read this
        once and keep the result in a local variable.
        """
        return get_prompt_to_use(None, self.prompt_id)

    @property
    def prompt_to_use_id(self) -> Optional[UUID]:
        """Prompt id resolved by ``get_prompt_to_use_id`` (looked up on every access)."""
        return get_prompt_to_use_id(None, self.prompt_id)

    def _create_llm(
        self, model, temperature=0, streaming=False, callbacks=None
    ) -> BaseChatModel:
        """
        Determine the language model to be used.
        :param model: Language model name to be used.
        :param temperature: Sampling temperature passed to the model.
        :param streaming: Whether to enable streaming of the model
        :param callbacks: Callbacks to be used for streaming
        :return: Language model instance
        """
        return ChatOpenAI(
            temperature=temperature,
            model=model,
            streaming=streaming,
            verbose=True,
            callbacks=callbacks,
            openai_api_key=self.openai_api_key,
        )

    def _create_prompt_template(self):
        """Build a chat prompt holding a single human ``{question}`` message."""
        messages = [
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        return ChatPromptTemplate.from_messages(messages)

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion
    ) -> GetChatHistoryOutput:
        """Answer ``question`` in one shot and persist the exchange.

        :param chat_id: Chat the new history row is written to.
        :param question: Question object; only ``question.question`` is read.
        :return: The stored exchange, including the generated answer.
        """
        transformed_history = format_chat_history(get_chat_history(self.chat_id))

        # Resolve the prompt once: the property performs a lookup per access.
        prompt = self.prompt_to_use
        prompt_content = prompt.content if prompt else SYSTEM_MESSAGE

        messages = format_history_to_openai_mesages(
            transformed_history, prompt_content, question.question
        )
        answering_llm = self._create_llm(
            model=self.model,
            # Forward the configured temperature instead of silently using 0.
            temperature=self.temperature,
            streaming=False,
            callbacks=self.callbacks,
        )
        model_prediction = answering_llm.predict_messages(messages)
        answer = model_prediction.content

        new_chat = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "brain_id": None,
                    "prompt_id": self.prompt_to_use_id,
                }
            )
        )

        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                "message_time": new_chat.message_time,
                "prompt_title": prompt.title if prompt else None,
                "brain_name": None,
                "message_id": new_chat.message_id,
            }
        )

    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion
    ) -> AsyncIterable:
        """Stream the answer to ``question`` token by token.

        A history row with an empty answer is created up front. Each yielded
        item is a string ``data: <json>`` whose JSON payload is that row with
        ``assistant`` set to the latest token only (not the accumulated text).
        Once the stream ends, the stored row is updated with the full answer.

        :param chat_id: Chat the new history row is written to.
        :param question: Question object; only ``question.question`` is read.
        """
        callback = AsyncIteratorCallbackHandler()
        self.callbacks = [callback]

        transformed_history = format_chat_history(get_chat_history(self.chat_id))

        # Resolve the prompt once: the property performs a lookup per access.
        prompt = self.prompt_to_use
        prompt_content = prompt.content if prompt else SYSTEM_MESSAGE

        messages = format_history_to_openai_mesages(
            transformed_history, prompt_content, question.question
        )
        answering_llm = self._create_llm(
            model=self.model,
            # Forward the configured temperature instead of silently using 0.
            temperature=self.temperature,
            streaming=True,
            callbacks=self.callbacks,
        )

        chat_prompt = ChatPromptTemplate.from_messages(messages)
        headless_chain = LLMChain(llm=answering_llm, prompt=chat_prompt)

        response_tokens = []

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            """Await ``fn`` and always set ``event``, even when ``fn`` fails."""
            try:
                await fn
            except Exception as e:
                # Best effort: log and let the token iterator terminate.
                logger.error(f"Caught exception: {e}")
            finally:
                event.set()

        # Create the placeholder row before scheduling the LLM call, so a
        # failing write cannot leave an orphaned background task behind. The
        # task could not start before the first await below anyway.
        created_row = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": "",
                    "brain_id": None,
                    "prompt_id": self.prompt_to_use_id,
                }
            )
        )

        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": created_row.message_id,
                "message_time": created_row.message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": prompt.title if prompt else None,
                "brain_name": None,
            }
        )

        run = asyncio.create_task(
            wrap_done(
                headless_chain.acall({}),
                callback.done,
            ),
        )

        async for token in callback.aiter():
            logger.info("Token: %s", token)
            response_tokens.append(token)
            streamed_chat_history.assistant = token
            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        await run

        # Persist the complete answer assembled from the streamed tokens.
        assistant = "".join(response_tokens)
        update_message_by_id(
            message_id=str(streamed_chat_history.message_id),
            user_message=question.question,
            assistant=assistant,
        )

    class Config:
        arbitrary_types_allowed = True