feat(Unplug): chatting without brain streaming (#970)

* feat(Unplug): Adds new basic headless llm

* feat(Unplug): adds chatting without brain option when no streaming

* feat(Unplug): adds chatting without brain option when streaming
Stepan Lebedev 2023-08-18 10:32:22 +02:00 committed by GitHub
parent 7281fd905a
commit 600ff1ede0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 300 additions and 75 deletions

View File

@@ -1,9 +1,11 @@
 from .base import BaseBrainPicking
 from .qa_base import QABaseBrainPicking
 from .openai import OpenAIBrainPicking
+from .qa_headless import HeadlessQA

 __all__ = [
     "BaseBrainPicking",
     "QABaseBrainPicking",
     "OpenAIBrainPicking",
+    "HeadlessQA"
 ]

View File

@@ -0,0 +1,207 @@
import asyncio
import json
from uuid import UUID

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)

from repository.chat.update_message_by_id import update_message_by_id
from models.databases.supabase.chats import CreateChatHistory
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.format_chat_history import format_history_to_openai_mesages
from logger import get_logger
from models.chats import ChatQuestion
from repository.chat.get_chat_history import GetChatHistoryOutput
from pydantic import BaseModel
from typing import AsyncIterable, Awaitable, List

logger = get_logger(__name__)

SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."


class HeadlessQA(BaseModel):
    model: str = None  # type: ignore
    temperature: float = 0.0
    max_tokens: int = 256
    user_openai_api_key: str = None  # type: ignore
    openai_api_key: str = None  # type: ignore
    streaming: bool = False
    chat_id: str = None  # type: ignore
    callbacks: List[AsyncIteratorCallbackHandler] = None  # type: ignore

    def _determine_api_key(self, openai_api_key, user_openai_api_key):
        """If user provided an API key, use it."""
        if user_openai_api_key is not None:
            return user_openai_api_key
        else:
            return openai_api_key

    def _determine_streaming(self, model: str, streaming: bool) -> bool:
        """If the model name allows for streaming and streaming is declared, set streaming to True."""
        return streaming

    def _determine_callback_array(
        self, streaming
    ) -> List[AsyncIteratorCallbackHandler]:  # pyright: ignore reportPrivateUsage=none
        """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
        if streaming:
            return [
                AsyncIteratorCallbackHandler()  # pyright: ignore reportPrivateUsage=none
            ]

    def __init__(self, **data):
        super().__init__(**data)
        self.openai_api_key = self._determine_api_key(
            self.openai_api_key, self.user_openai_api_key
        )
        self.streaming = self._determine_streaming(
            self.model, self.streaming
        )  # pyright: ignore reportPrivateUsage=none
        self.callbacks = self._determine_callback_array(
            self.streaming
        )  # pyright: ignore reportPrivateUsage=none

    def _create_llm(
        self, model, temperature=0, streaming=False, callbacks=None
    ) -> BaseLLM:
        """
        Determine the language model to be used.
        :param model: Language model name to be used.
        :param streaming: Whether to enable streaming of the model
        :param callbacks: Callbacks to be used for streaming
        :return: Language model instance
        """
        return ChatOpenAI(
            temperature=temperature,
            model=model,
            streaming=streaming,
            verbose=True,
            callbacks=callbacks,
            openai_api_key=self.openai_api_key,
        )  # pyright: ignore reportPrivateUsage=none

    def _create_prompt_template(self):
        messages = [
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
        return CHAT_PROMPT
    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion
    ) -> GetChatHistoryOutput:
        transformed_history = format_chat_history(get_chat_history(self.chat_id))
        messages = format_history_to_openai_mesages(transformed_history, SYSTEM_MESSAGE, question.question)
        answering_llm = self._create_llm(
            model=self.model, streaming=False, callbacks=self.callbacks
        )
        model_prediction = answering_llm.predict_messages(messages)  # pyright: ignore reportPrivateUsage=none
        answer = model_prediction.content

        new_chat = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "brain_id": None,
                    "prompt_id": None,
                }
            )
        )

        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                "message_time": new_chat.message_time,
                "prompt_title": None,
                "brain_name": None,
                "message_id": new_chat.message_id,
            }
        )
    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion
    ) -> AsyncIterable:
        callback = AsyncIteratorCallbackHandler()
        self.callbacks = [callback]

        transformed_history = format_chat_history(get_chat_history(self.chat_id))
        messages = format_history_to_openai_mesages(transformed_history, SYSTEM_MESSAGE, question.question)
        answering_llm = self._create_llm(
            model=self.model, streaming=True, callbacks=self.callbacks
        )
        CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
        headlessChain = LLMChain(llm=answering_llm, prompt=CHAT_PROMPT)

        response_tokens = []

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
            except Exception as e:
                logger.error(f"Caught exception: {e}")
            finally:
                event.set()

        run = asyncio.create_task(
            wrap_done(
                headlessChain.acall({}),
                callback.done,
            ),
        )

        streamed_chat_history = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": "",
                    "brain_id": None,
                    "prompt_id": None,
                }
            )
        )

        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": streamed_chat_history.message_id,
                "message_time": streamed_chat_history.message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": None,
                "brain_name": None,
            }
        )

        async for token in callback.aiter():
            logger.info("Token: %s", token)  # type: ignore
            response_tokens.append(token)  # type: ignore
            streamed_chat_history.assistant = token  # type: ignore
            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        await run

        assistant = "".join(response_tokens)

        update_message_by_id(
            message_id=str(streamed_chat_history.message_id),
            user_message=question.question,
            assistant=assistant,
        )

    class Config:
        arbitrary_types_allowed = True
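
For orientation, a minimal usage sketch of the new class, mirroring how the chat routes below construct it. The chat_id, question text, and keyword arguments passed to ChatQuestion are illustrative assumptions, and the sketch presumes OPENAI_API_KEY and the Supabase-backed chat repository are configured, since generate_answer and generate_stream persist to chat history.

# Illustrative sketch only -- chat_id and ChatQuestion values are placeholders.
import asyncio
from uuid import UUID

from llm.qa_headless import HeadlessQA
from models.chats import ChatQuestion

chat_id = UUID("00000000-0000-0000-0000-000000000000")  # assumed to reference an existing chat row
question = ChatQuestion(
    question="What is Quivr?", model="gpt-3.5-turbo", temperature=0.0, max_tokens=256
)

# Non-streaming path: returns a GetChatHistoryOutput whose .assistant holds the answer.
qa = HeadlessQA(model=question.model, chat_id=str(chat_id))
print(qa.generate_answer(chat_id, question).assistant)

# Streaming path: yields "data: {...}" chunks as tokens arrive from the callback handler.
async def consume() -> None:
    qa_stream = HeadlessQA(model=question.model, chat_id=str(chat_id), streaming=True)
    async for chunk in qa_stream.generate_stream(chat_id, question):
        print(chunk)

asyncio.run(consume())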

View File

@@ -1,4 +1,19 @@
+from typing import List, Tuple
+
+from langchain.schema import AIMessage, HumanMessage, SystemMessage
+
+
-def format_chat_history(history) -> list[tuple[str, str]]:
+def format_chat_history(history) -> List[Tuple[str, str]]:
     """Format the chat history into a list of tuples (human, ai)"""
     return [(chat.user_message, chat.assistant) for chat in history]
+
+
+def format_history_to_openai_mesages(tuple_history: List[Tuple[str, str]], system_message: str, question: str) -> List[SystemMessage | HumanMessage | AIMessage]:
+    """Format the chat history into a list of Base Messages"""
+    messages = []
+    messages.append(SystemMessage(content=system_message))
+    for human, ai in tuple_history:
+        messages.append(HumanMessage(content=human))
+        messages.append(AIMessage(content=ai))
+    messages.append(HumanMessage(content=question))
+    return messages
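
As a quick illustration of the new helper, the sketch below shows the message list it builds; the history tuples and question are invented for the example.

# Example input/output for format_history_to_openai_mesages; values are made up.
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from repository.chat.format_chat_history import format_history_to_openai_mesages

history = [("Hi", "Hello! How can I help you today?")]
messages = format_history_to_openai_mesages(history, "You are Quivr.", "What can you do?")
# messages is now:
#   [SystemMessage("You are Quivr."), HumanMessage("Hi"),
#    AIMessage("Hello! How can I help you today?"), HumanMessage("What can you do?")]
assert isinstance(messages[0], SystemMessage)
assert isinstance(messages[-1], HumanMessage)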

View File

@@ -7,6 +7,7 @@ from venv import logger
 from auth import AuthBearer, get_current_user
 from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from fastapi.responses import StreamingResponse
+from llm.qa_headless import HeadlessQA
 from llm.openai import OpenAIBrainPicking
 from models.brains import Brain
 from models.brain_entity import BrainEntity
@@ -16,9 +17,6 @@ from models.databases.supabase.supabase import SupabaseDB
 from models.settings import LLMSettings, get_supabase_db
 from models.users import User
 from repository.brain.get_brain_details import get_brain_details
-from repository.brain.get_default_user_brain_or_create_new import (
-    get_default_user_brain_or_create_new,
-)
 from repository.chat.create_chat import CreateChatProperties, create_chat
 from repository.chat.get_chat_by_id import get_chat_by_id
 from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
@@ -190,17 +188,24 @@ async def create_question_handler(
     check_user_limit(current_user)
     LLMSettings()

-    if not brain_id:
-        brain_id = get_default_user_brain_or_create_new(current_user).brain_id
-
-    gpt_answer_generator = OpenAIBrainPicking(
-        chat_id=str(chat_id),
-        model=chat_question.model,
-        max_tokens=chat_question.max_tokens,
-        temperature=chat_question.temperature,
-        brain_id=str(brain_id),
-        user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
-    )
+    gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
+    if brain_id:
+        gpt_answer_generator = OpenAIBrainPicking(
+            chat_id=str(chat_id),
+            model=chat_question.model,
+            max_tokens=chat_question.max_tokens,
+            temperature=chat_question.temperature,
+            brain_id=str(brain_id),
+            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+        )
+    else:
+        gpt_answer_generator = HeadlessQA(
+            model=chat_question.model,
+            temperature=chat_question.temperature,
+            max_tokens=chat_question.max_tokens,
+            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+            chat_id=str(chat_id),
+        )

     chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)
@@ -259,18 +264,26 @@ async def create_stream_question_handler(
     try:
         logger.info(f"Streaming request for {chat_question.model}")
         check_user_limit(current_user)

-        if not brain_id:
-            brain_id = get_default_user_brain_or_create_new(current_user).brain_id
-
-        gpt_answer_generator = OpenAIBrainPicking(
-            chat_id=str(chat_id),
-            model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo",
-            max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0,
-            temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256,
-            brain_id=str(brain_id),
-            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
-            streaming=True,
-        )
+        gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
+        if brain_id:
+            gpt_answer_generator = OpenAIBrainPicking(
+                chat_id=str(chat_id),
+                model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo",
+                max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0,
+                temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256,
+                brain_id=str(brain_id),
+                user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+                streaming=True,
+            )
+        else:
+            gpt_answer_generator = HeadlessQA(
+                model=chat_question.model if current_user.user_openai_api_key else "gpt-3.5-turbo",
+                temperature=chat_question.temperature if current_user.user_openai_api_key else 256,
+                max_tokens=chat_question.max_tokens if current_user.user_openai_api_key else 0,
+                user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+                chat_id=str(chat_id),
+                streaming=True,
+            )

         print("streaming")
         return StreamingResponse(
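
A hedged client-side sketch of what this route change enables: the streaming endpoint (path taken from the frontend hook below) can now be called with an empty brain_id, which sends the request down the HeadlessQA branch. The base URL, bearer token, chat_id, and payload fields are placeholders, and httpx is only one way to consume the "data: "-prefixed chunks.

# Placeholder values throughout -- BASE_URL, TOKEN and chat_id are not part of this commit.
import json

import httpx

BASE_URL = "http://localhost:5050"  # assumed backend address
TOKEN = "<jwt>"  # bearer token expected by AuthBearer
chat_id = "00000000-0000-0000-0000-000000000000"

payload = {"question": "What is Quivr?", "model": "gpt-3.5-turbo", "temperature": 0.0, "max_tokens": 256}
headers = {
    "Authorization": f"Bearer {TOKEN}",
    "Content-Type": "application/json",
    "Accept": "text/event-stream",
}

# brain_id left empty, as the frontend does when no brain is selected -> HeadlessQA branch.
with httpx.stream(
    "POST",
    f"{BASE_URL}/chat/{chat_id}/question/stream?brain_id=",
    headers=headers,
    content=json.dumps(payload),
    timeout=None,
) as response:
    for chunk in response.iter_text():
        for data in filter(None, chunk.split("data: ")):
            print(json.loads(data).get("assistant", ""), end="", flush=True)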

View File

@@ -1,18 +1,18 @@
 /* eslint-disable max-lines */
-import { AxiosError } from "axios";
+import { AxiosError } from 'axios';
-import { useParams } from "next/navigation";
+import { useParams } from 'next/navigation';
-import { useState } from "react";
+import { useState } from 'react';
-import { useTranslation } from "react-i18next";
+import { useTranslation } from 'react-i18next';

-import { getChatConfigFromLocalStorage } from "@/lib/api/chat/chat.local";
+import { getChatConfigFromLocalStorage } from '@/lib/api/chat/chat.local';
-import { useChatApi } from "@/lib/api/chat/useChatApi";
+import { useChatApi } from '@/lib/api/chat/useChatApi';
-import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
+import { useBrainContext } from '@/lib/context/BrainProvider/hooks/useBrainContext';
-import { useChatContext } from "@/lib/context/ChatProvider/hooks/useChatContext";
+import { useChatContext } from '@/lib/context/ChatProvider/hooks/useChatContext';
-import { useToast } from "@/lib/hooks";
+import { useToast } from '@/lib/hooks';
-import { useEventTracking } from "@/services/analytics/useEventTracking";
+import { useEventTracking } from '@/services/analytics/useEventTracking';

-import { useQuestion } from "./useQuestion";
+import { useQuestion } from './useQuestion';
-import { ChatQuestion } from "../types";
+import { ChatQuestion } from '../types';

 // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
 export const useChat = () => {
@@ -29,7 +29,7 @@ export const useChat = () => {
   const { createChat } = useChatApi();
   const { addStreamQuestion } = useQuestion();
-  const { t } = useTranslation(["chat"]);
+  const { t } = useTranslation(['chat']);

   const addQuestion = async (question: string, callback?: () => void) => {
     try {
@@ -39,22 +39,13 @@ export const useChat = () => {
       //if chatId is not set, create a new chat. Chat name is from the first question
       if (currentChatId === undefined) {
-        const chatName = question.split(" ").slice(0, 3).join(" ");
+        const chatName = question.split(' ').slice(0, 3).join(' ');
         const chat = await createChat(chatName);
         currentChatId = chat.chat_id;
         setChatId(currentChatId);
       }

-      if (currentBrain?.id === undefined) {
-        publish({
-          variant: "danger",
-          text: t("missing_brain"),
-        });
-
-        return;
-      }
-
-      void track("QUESTION_ASKED");
+      void track('QUESTION_ASKED');

       const chatConfig = getChatConfigFromLocalStorage(currentChatId);
       const chatQuestion: ChatQuestion = {
@@ -62,7 +53,7 @@ export const useChat = () => {
         question,
         temperature: chatConfig?.temperature,
         max_tokens: chatConfig?.maxTokens,
-        brain_id: currentBrain.id,
+        brain_id: currentBrain?.id,
       };

       await addStreamQuestion(currentChatId, chatQuestion);
@@ -73,16 +64,16 @@ export const useChat = () => {
       if ((error as AxiosError).response?.status === 429) {
         publish({
-          variant: "danger",
-          text: t("limit_reached", { ns: "chat" }),
+          variant: 'danger',
+          text: t('limit_reached', { ns: 'chat' }),
         });

         return;
       }

       publish({
-        variant: "danger",
-        text: t("error_occurred", { ns: "chat" }),
+        variant: 'danger',
+        text: t('error_occurred', { ns: 'chat' }),
       });
     } finally {
       setGeneratingAnswer(false);

View File

@@ -1,12 +1,12 @@
 /* eslint-disable max-lines */
-import { useTranslation } from "react-i18next";
+import { useTranslation } from 'react-i18next';

-import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
+import { useBrainContext } from '@/lib/context/BrainProvider/hooks/useBrainContext';
-import { useChatContext } from "@/lib/context/ChatProvider/hooks/useChatContext";
+import { useChatContext } from '@/lib/context/ChatProvider/hooks/useChatContext';
-import { useFetch } from "@/lib/hooks";
+import { useFetch } from '@/lib/hooks';

-import { ChatHistory, ChatQuestion } from "../types";
+import { ChatHistory, ChatQuestion } from '../types';

 interface UseChatService {
   addStreamQuestion: (
@@ -20,12 +20,12 @@ export const useQuestion = (): UseChatService => {
   const { updateStreamingHistory } = useChatContext();
   const { currentBrain } = useBrainContext();
-  const { t } = useTranslation(["chat"]);
+  const { t } = useTranslation(['chat']);

   const handleStream = async (
     reader: ReadableStreamDefaultReader<Uint8Array>
   ): Promise<void> => {
-    const decoder = new TextDecoder("utf-8");
+    const decoder = new TextDecoder('utf-8');

     const handleStreamRecursively = async () => {
       const { done, value } = await reader.read();
@@ -37,7 +37,7 @@ export const useQuestion = (): UseChatService => {
       const dataStrings = decoder
         .decode(value)
         .trim()
-        .split("data: ")
+        .split('data: ')
         .filter(Boolean);

       dataStrings.forEach((data) => {
@@ -45,7 +45,7 @@ export const useQuestion = (): UseChatService => {
           const parsedData = JSON.parse(data) as ChatHistory;
           updateStreamingHistory(parsedData);
         } catch (error) {
-          console.error(t("errorParsingData", { ns: "chat" }), error);
+          console.error(t('errorParsingData', { ns: 'chat' }), error);
         }
       });
@@ -59,30 +59,27 @@ export const useQuestion = (): UseChatService => {
     chatId: string,
     chatQuestion: ChatQuestion
   ): Promise<void> => {
-    if (currentBrain?.id === undefined) {
-      throw new Error(t("noCurrentBrain", { ns: "chat" }));
-    }
-
     const headers = {
-      "Content-Type": "application/json",
-      Accept: "text/event-stream",
+      'Content-Type': 'application/json',
+      Accept: 'text/event-stream',
     };
     const body = JSON.stringify(chatQuestion);
-    console.log("Calling API...");
+    console.log('Calling API...');
     try {
       const response = await fetchInstance.post(
-        `/chat/${chatId}/question/stream?brain_id=${currentBrain.id}`,
+        `/chat/${chatId}/question/stream?brain_id=${currentBrain?.id ?? ''}`,
         body,
         headers
       );
       if (response.body === null) {
-        throw new Error(t("resposeBodyNull", { ns: "chat" }));
+        throw new Error(t('resposeBodyNull', { ns: 'chat' }));
       }
-      console.log(t("receivedResponse"), response);
+      console.log(t('receivedResponse'), response);
       await handleStream(response.body.getReader());
     } catch (error) {
-      console.error(t("errorCallingAPI", { ns: "chat" }), error);
+      console.error(t('errorCallingAPI', { ns: 'chat' }), error);
     }
   };