diff --git a/backend/core/llm/__init__.py b/backend/core/llm/__init__.py index c2262f0fd..f17150037 100644 --- a/backend/core/llm/__init__.py +++ b/backend/core/llm/__init__.py @@ -1,9 +1,11 @@ from .base import BaseBrainPicking from .qa_base import QABaseBrainPicking from .openai import OpenAIBrainPicking +from .qa_headless import HeadlessQA __all__ = [ "BaseBrainPicking", "QABaseBrainPicking", "OpenAIBrainPicking", + "HeadlessQA" ] diff --git a/backend/core/llm/qa_headless.py b/backend/core/llm/qa_headless.py new file mode 100644 index 000000000..74889ac85 --- /dev/null +++ b/backend/core/llm/qa_headless.py @@ -0,0 +1,207 @@ +import asyncio +import json +from uuid import UUID + +from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler +from langchain.chat_models import ChatOpenAI +from langchain.chains import LLMChain +from langchain.llms.base import BaseLLM +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, +) +from repository.chat.update_message_by_id import update_message_by_id +from models.databases.supabase.chats import CreateChatHistory +from repository.chat.format_chat_history import format_chat_history +from repository.chat.get_chat_history import get_chat_history +from repository.chat.update_chat_history import update_chat_history +from repository.chat.format_chat_history import format_history_to_openai_mesages +from logger import get_logger +from models.chats import ChatQuestion +from repository.chat.get_chat_history import GetChatHistoryOutput + + +from pydantic import BaseModel + +from typing import AsyncIterable, Awaitable, List + +logger = get_logger(__name__) +SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer." + + +class HeadlessQA(BaseModel): + model: str = None # type: ignore + temperature: float = 0.0 + max_tokens: int = 256 + user_openai_api_key: str = None # type: ignore + openai_api_key: str = None # type: ignore + streaming: bool = False + chat_id: str = None # type: ignore + callbacks: List[AsyncIteratorCallbackHandler] = None # type: ignore + + def _determine_api_key(self, openai_api_key, user_openai_api_key): + """If user provided an API key, use it.""" + if user_openai_api_key is not None: + return user_openai_api_key + else: + return openai_api_key + + def _determine_streaming(self, model: str, streaming: bool) -> bool: + """If the model name allows for streaming and streaming is declared, set streaming to True.""" + return streaming + + def _determine_callback_array( + self, streaming + ) -> List[AsyncIteratorCallbackHandler]: # pyright: ignore reportPrivateUsage=none + """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback.""" + if streaming: + return [ + AsyncIteratorCallbackHandler() # pyright: ignore reportPrivateUsage=none + ] + + def __init__(self, **data): + super().__init__(**data) + + self.openai_api_key = self._determine_api_key( + self.openai_api_key, self.user_openai_api_key + ) + self.streaming = self._determine_streaming( + self.model, self.streaming + ) # pyright: ignore reportPrivateUsage=none + self.callbacks = self._determine_callback_array( + self.streaming + ) # pyright: ignore reportPrivateUsage=none + + def _create_llm( + self, model, temperature=0, streaming=False, callbacks=None + ) -> BaseLLM: + """ + Determine the language model to be used. + :param model: Language model name to be used. + :param streaming: Whether to enable streaming of the model + :param callbacks: Callbacks to be used for streaming + :return: Language model instance + """ + return ChatOpenAI( + temperature=temperature, + model=model, + streaming=streaming, + verbose=True, + callbacks=callbacks, + openai_api_key=self.openai_api_key, + ) # pyright: ignore reportPrivateUsage=none + + def _create_prompt_template(self): + messages = [ + HumanMessagePromptTemplate.from_template("{question}"), + ] + CHAT_PROMPT = ChatPromptTemplate.from_messages(messages) + return CHAT_PROMPT + + def generate_answer( + self, chat_id: UUID, question: ChatQuestion + ) -> GetChatHistoryOutput: + transformed_history = format_chat_history(get_chat_history(self.chat_id)) + messages = format_history_to_openai_mesages(transformed_history, SYSTEM_MESSAGE, question.question) + answering_llm = self._create_llm( + model=self.model, streaming=False, callbacks=self.callbacks + ) + model_prediction = answering_llm.predict_messages(messages) # pyright: ignore reportPrivateUsage=none + answer = model_prediction.content + + new_chat = update_chat_history( + CreateChatHistory( + **{ + "chat_id": chat_id, + "user_message": question.question, + "assistant": answer, + "brain_id": None, + "prompt_id": None, + } + ) + ) + + return GetChatHistoryOutput( + **{ + "chat_id": chat_id, + "user_message": question.question, + "assistant": answer, + "message_time": new_chat.message_time, + "prompt_title": None, + "brain_name": None, + "message_id": new_chat.message_id, + } + ) + + async def generate_stream( + self, chat_id: UUID, question: ChatQuestion + ) -> AsyncIterable: + callback = AsyncIteratorCallbackHandler() + self.callbacks = [callback] + + transformed_history = format_chat_history(get_chat_history(self.chat_id)) + messages = format_history_to_openai_mesages(transformed_history, SYSTEM_MESSAGE, question.question) + answering_llm = self._create_llm( + model=self.model, streaming=True, callbacks=self.callbacks + ) + + CHAT_PROMPT = ChatPromptTemplate.from_messages(messages) + headlessChain = LLMChain(llm=answering_llm, prompt=CHAT_PROMPT) + + response_tokens = [] + + async def wrap_done(fn: Awaitable, event: asyncio.Event): + try: + await fn + except Exception as e: + logger.error(f"Caught exception: {e}") + finally: + event.set() + run = asyncio.create_task( + wrap_done( + headlessChain.acall({}), + callback.done, + ), + ) + + streamed_chat_history = update_chat_history( + CreateChatHistory( + **{ + "chat_id": chat_id, + "user_message": question.question, + "assistant": "", + "brain_id": None, + "prompt_id": None, + } + ) + ) + + streamed_chat_history = GetChatHistoryOutput( + **{ + "chat_id": str(chat_id), + "message_id": streamed_chat_history.message_id, + "message_time": streamed_chat_history.message_time, + "user_message": question.question, + "assistant": "", + "prompt_title": None, + "brain_name": None, + } + ) + + async for token in callback.aiter(): + logger.info("Token: %s", token) # type: ignore + response_tokens.append(token) # type: ignore + streamed_chat_history.assistant = token # type: ignore + yield f"data: {json.dumps(streamed_chat_history.dict())}" + + await run + assistant = "".join(response_tokens) + + update_message_by_id( + message_id=str(streamed_chat_history.message_id), + user_message=question.question, + assistant=assistant, + ) + + class Config: + arbitrary_types_allowed = True diff --git a/backend/core/repository/chat/format_chat_history.py b/backend/core/repository/chat/format_chat_history.py index 2f91dd37f..9941a7ca5 100644 --- a/backend/core/repository/chat/format_chat_history.py +++ b/backend/core/repository/chat/format_chat_history.py @@ -1,4 +1,19 @@ -def format_chat_history(history) -> list[tuple[str, str]]: +from typing import List, Tuple +from langchain.schema import AIMessage, HumanMessage, SystemMessage + + +def format_chat_history(history) -> List[Tuple[str, str]]: """Format the chat history into a list of tuples (human, ai)""" return [(chat.user_message, chat.assistant) for chat in history] + + +def format_history_to_openai_mesages(tuple_history: List[Tuple[str, str]], system_message: str, question: str) -> List[SystemMessage | HumanMessage | AIMessage]: + """Format the chat history into a list of Base Messages""" + messages = [] + messages.append(SystemMessage(content=system_message)) + for human, ai in tuple_history: + messages.append(HumanMessage(content=human)) + messages.append(AIMessage(content=ai)) + messages.append(HumanMessage(content=question)) + return messages diff --git a/backend/core/routes/chat_routes.py b/backend/core/routes/chat_routes.py index 1da1fbcbf..aaf3dc9a4 100644 --- a/backend/core/routes/chat_routes.py +++ b/backend/core/routes/chat_routes.py @@ -7,6 +7,7 @@ from venv import logger from auth import AuthBearer, get_current_user from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.responses import StreamingResponse +from llm.qa_headless import HeadlessQA from llm.openai import OpenAIBrainPicking from models.brains import Brain from models.brain_entity import BrainEntity @@ -16,9 +17,6 @@ from models.databases.supabase.supabase import SupabaseDB from models.settings import LLMSettings, get_supabase_db from models.users import User from repository.brain.get_brain_details import get_brain_details -from repository.brain.get_default_user_brain_or_create_new import ( - get_default_user_brain_or_create_new, -) from repository.chat.create_chat import CreateChatProperties, create_chat from repository.chat.get_chat_by_id import get_chat_by_id from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history @@ -190,17 +188,24 @@ async def create_question_handler( check_user_limit(current_user) LLMSettings() - if not brain_id: - brain_id = get_default_user_brain_or_create_new(current_user).brain_id - - gpt_answer_generator = OpenAIBrainPicking( - chat_id=str(chat_id), - model=chat_question.model, - max_tokens=chat_question.max_tokens, - temperature=chat_question.temperature, - brain_id=str(brain_id), - user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none - ) + gpt_answer_generator: HeadlessQA | OpenAIBrainPicking + if brain_id: + gpt_answer_generator = OpenAIBrainPicking( + chat_id=str(chat_id), + model=chat_question.model, + max_tokens=chat_question.max_tokens, + temperature=chat_question.temperature, + brain_id=str(brain_id), + user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none + ) + else: + gpt_answer_generator = HeadlessQA( + model=chat_question.model, + temperature=chat_question.temperature, + max_tokens=chat_question.max_tokens, + user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none + chat_id=str(chat_id), + ) chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question) @@ -259,18 +264,26 @@ async def create_stream_question_handler( try: logger.info(f"Streaming request for {chat_question.model}") check_user_limit(current_user) - if not brain_id: - brain_id = get_default_user_brain_or_create_new(current_user).brain_id - - gpt_answer_generator = OpenAIBrainPicking( - chat_id=str(chat_id), - model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo", - max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0, - temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256, - brain_id=str(brain_id), - user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none - streaming=True, - ) + gpt_answer_generator: HeadlessQA | OpenAIBrainPicking + if brain_id: + gpt_answer_generator = OpenAIBrainPicking( + chat_id=str(chat_id), + model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo", + max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0, + temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256, + brain_id=str(brain_id), + user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none + streaming=True, + ) + else: + gpt_answer_generator = HeadlessQA( + model=chat_question.model if current_user.user_openai_api_key else "gpt-3.5-turbo", + temperature=chat_question.temperature if current_user.user_openai_api_key else 256, + max_tokens=chat_question.max_tokens if current_user.user_openai_api_key else 0, + user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none + chat_id=str(chat_id), + streaming=True, + ) print("streaming") return StreamingResponse( diff --git a/frontend/app/chat/[chatId]/hooks/useChat.ts b/frontend/app/chat/[chatId]/hooks/useChat.ts index 65c8255a1..7bb366625 100644 --- a/frontend/app/chat/[chatId]/hooks/useChat.ts +++ b/frontend/app/chat/[chatId]/hooks/useChat.ts @@ -1,18 +1,18 @@ /* eslint-disable max-lines */ -import { AxiosError } from "axios"; -import { useParams } from "next/navigation"; -import { useState } from "react"; -import { useTranslation } from "react-i18next"; +import { AxiosError } from 'axios'; +import { useParams } from 'next/navigation'; +import { useState } from 'react'; +import { useTranslation } from 'react-i18next'; -import { getChatConfigFromLocalStorage } from "@/lib/api/chat/chat.local"; -import { useChatApi } from "@/lib/api/chat/useChatApi"; -import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useChatContext } from "@/lib/context/ChatProvider/hooks/useChatContext"; -import { useToast } from "@/lib/hooks"; -import { useEventTracking } from "@/services/analytics/useEventTracking"; +import { getChatConfigFromLocalStorage } from '@/lib/api/chat/chat.local'; +import { useChatApi } from '@/lib/api/chat/useChatApi'; +import { useBrainContext } from '@/lib/context/BrainProvider/hooks/useBrainContext'; +import { useChatContext } from '@/lib/context/ChatProvider/hooks/useChatContext'; +import { useToast } from '@/lib/hooks'; +import { useEventTracking } from '@/services/analytics/useEventTracking'; -import { useQuestion } from "./useQuestion"; -import { ChatQuestion } from "../types"; +import { useQuestion } from './useQuestion'; +import { ChatQuestion } from '../types'; // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types export const useChat = () => { @@ -29,7 +29,7 @@ export const useChat = () => { const { createChat } = useChatApi(); const { addStreamQuestion } = useQuestion(); - const { t } = useTranslation(["chat"]); + const { t } = useTranslation(['chat']); const addQuestion = async (question: string, callback?: () => void) => { try { @@ -39,22 +39,13 @@ export const useChat = () => { //if chatId is not set, create a new chat. Chat name is from the first question if (currentChatId === undefined) { - const chatName = question.split(" ").slice(0, 3).join(" "); + const chatName = question.split(' ').slice(0, 3).join(' '); const chat = await createChat(chatName); currentChatId = chat.chat_id; setChatId(currentChatId); } - if (currentBrain?.id === undefined) { - publish({ - variant: "danger", - text: t("missing_brain"), - }); - - return; - } - - void track("QUESTION_ASKED"); + void track('QUESTION_ASKED'); const chatConfig = getChatConfigFromLocalStorage(currentChatId); const chatQuestion: ChatQuestion = { @@ -62,7 +53,7 @@ export const useChat = () => { question, temperature: chatConfig?.temperature, max_tokens: chatConfig?.maxTokens, - brain_id: currentBrain.id, + brain_id: currentBrain?.id, }; await addStreamQuestion(currentChatId, chatQuestion); @@ -73,16 +64,16 @@ export const useChat = () => { if ((error as AxiosError).response?.status === 429) { publish({ - variant: "danger", - text: t("limit_reached", { ns: "chat" }), + variant: 'danger', + text: t('limit_reached', { ns: 'chat' }), }); return; } publish({ - variant: "danger", - text: t("error_occurred", { ns: "chat" }), + variant: 'danger', + text: t('error_occurred', { ns: 'chat' }), }); } finally { setGeneratingAnswer(false); diff --git a/frontend/app/chat/[chatId]/hooks/useQuestion.ts b/frontend/app/chat/[chatId]/hooks/useQuestion.ts index a9a31591e..3337fa995 100644 --- a/frontend/app/chat/[chatId]/hooks/useQuestion.ts +++ b/frontend/app/chat/[chatId]/hooks/useQuestion.ts @@ -1,12 +1,12 @@ /* eslint-disable max-lines */ -import { useTranslation } from "react-i18next"; +import { useTranslation } from 'react-i18next'; -import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useChatContext } from "@/lib/context/ChatProvider/hooks/useChatContext"; -import { useFetch } from "@/lib/hooks"; +import { useBrainContext } from '@/lib/context/BrainProvider/hooks/useBrainContext'; +import { useChatContext } from '@/lib/context/ChatProvider/hooks/useChatContext'; +import { useFetch } from '@/lib/hooks'; -import { ChatHistory, ChatQuestion } from "../types"; +import { ChatHistory, ChatQuestion } from '../types'; interface UseChatService { addStreamQuestion: ( @@ -20,12 +20,12 @@ export const useQuestion = (): UseChatService => { const { updateStreamingHistory } = useChatContext(); const { currentBrain } = useBrainContext(); - const { t } = useTranslation(["chat"]); + const { t } = useTranslation(['chat']); const handleStream = async ( reader: ReadableStreamDefaultReader ): Promise => { - const decoder = new TextDecoder("utf-8"); + const decoder = new TextDecoder('utf-8'); const handleStreamRecursively = async () => { const { done, value } = await reader.read(); @@ -37,7 +37,7 @@ export const useQuestion = (): UseChatService => { const dataStrings = decoder .decode(value) .trim() - .split("data: ") + .split('data: ') .filter(Boolean); dataStrings.forEach((data) => { @@ -45,7 +45,7 @@ export const useQuestion = (): UseChatService => { const parsedData = JSON.parse(data) as ChatHistory; updateStreamingHistory(parsedData); } catch (error) { - console.error(t("errorParsingData", { ns: "chat" }), error); + console.error(t('errorParsingData', { ns: 'chat' }), error); } }); @@ -59,30 +59,27 @@ export const useQuestion = (): UseChatService => { chatId: string, chatQuestion: ChatQuestion ): Promise => { - if (currentBrain?.id === undefined) { - throw new Error(t("noCurrentBrain", { ns: "chat" })); - } const headers = { - "Content-Type": "application/json", - Accept: "text/event-stream", + 'Content-Type': 'application/json', + Accept: 'text/event-stream', }; const body = JSON.stringify(chatQuestion); - console.log("Calling API..."); + console.log('Calling API...'); try { const response = await fetchInstance.post( - `/chat/${chatId}/question/stream?brain_id=${currentBrain.id}`, + `/chat/${chatId}/question/stream?brain_id=${currentBrain?.id ?? ''}`, body, headers ); if (response.body === null) { - throw new Error(t("resposeBodyNull", { ns: "chat" })); + throw new Error(t('resposeBodyNull', { ns: 'chat' })); } - console.log(t("receivedResponse"), response); + console.log(t('receivedResponse'), response); await handleStream(response.body.getReader()); } catch (error) { - console.error(t("errorCallingAPI", { ns: "chat" }), error); + console.error(t('errorCallingAPI', { ns: 'chat' }), error); } };