feat(chat): added streaming (#808)

* feat(tmp): added streaming

* feat(streaming): implemented by changing order
Stan Girard 2023-07-31 21:34:34 +02:00 committed by GitHub
parent db40f3cccd
commit 3166d089ee
8 changed files with 52 additions and 61 deletions

View File

@@ -1,14 +1,12 @@
from abc import abstractmethod
from typing import AsyncIterable, List
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.llms.base import LLM
from logger import get_logger
from models.settings import BrainSettings # Importing settings related to the 'brain'
from pydantic import BaseModel # For data validation and settings management
from utils.constants import streaming_compatible_models
logger = get_logger(__name__)
@@ -33,7 +31,7 @@ class BaseBrainPicking(BaseModel):
openai_api_key: str = None # pyright: ignore reportPrivateUsage=none
callbacks: List[
AsyncCallbackHandler
AsyncIteratorCallbackHandler
] = None # pyright: ignore reportPrivateUsage=none
def _determine_api_key(self, openai_api_key, user_openai_api_key):
@@ -45,23 +43,14 @@ class BaseBrainPicking(BaseModel):
def _determine_streaming(self, model: str, streaming: bool) -> bool:
"""If the model name allows for streaming and streaming is declared, set streaming to True."""
if model in streaming_compatible_models and streaming:
return True
if model not in streaming_compatible_models and streaming:
logger.warning(
f"Streaming is not compatible with {model}. Streaming will be set to False."
)
return False
else:
return False
return streaming
def _determine_callback_array(
self, streaming
) -> List[AsyncIteratorCallbackHandler]: # pyright: ignore reportPrivateUsage=none
"""If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
if streaming:
return [
AsyncIteratorCallbackHandler # pyright: ignore reportPrivateUsage=none
AsyncIteratorCallbackHandler() # pyright: ignore reportPrivateUsage=none
]
def __init__(self, **data):
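The important fix in `_determine_callback_array` is that it now returns an instance, `AsyncIteratorCallbackHandler()`, rather than the class itself: LangChain invokes the callback hooks on that object, and the streamed tokens are later read back from the same object. A minimal sketch of how such a handler is consumed (standalone here; in the commit it is attached to the streaming LLM):

```python
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler

callback = AsyncIteratorCallbackHandler()  # an instance, not the class

async def consume() -> str:
    tokens = []
    # aiter() yields tokens as the LLM pushes them into the handler's queue,
    # and stops once callback.done is set (see wrap_done further down).
    async for token in callback.aiter():
        tokens.append(token)
    return "".join(tokens)
```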

View File

@@ -58,5 +58,6 @@ class OpenAIBrainPicking(QABaseBrainPicking):
temperature=self.temperature,
model=model,
streaming=streaming,
verbose=True,
callbacks=callbacks,
) # pyright: ignore reportPrivateUsage=none
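For reference, the streaming LLM produced by `_create_llm` now amounts to a `ChatOpenAI` wired to the iterator callback; a rough sketch with illustrative parameter values:

```python
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI

callback = AsyncIteratorCallbackHandler()
llm = ChatOpenAI(
    temperature=0.0,        # illustrative; the class passes self.temperature
    model="gpt-3.5-turbo",  # illustrative; the class passes self.model
    streaming=True,         # emit tokens as they are generated
    verbose=True,
    callbacks=[callback],   # tokens land in the handler's queue
)
```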

View File

@@ -1,20 +1,21 @@
import asyncio
import json
from abc import abstractmethod, abstractproperty
from typing import AsyncIterable, Awaitable
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from logger import get_logger
from models.chat import ChatHistory
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
from langchain.chat_models import ChatOpenAI
from repository.chat.update_message_by_id import update_message_by_id
import json
from .base import BaseBrainPicking
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
@@ -60,13 +61,13 @@ class QABaseBrainPicking(BaseBrainPicking):
@property
def vector_store(self) -> CustomSupabaseVectorStore:
return CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
table_name="vectors",
brain_id=self.brain_id,
)
@property
def question_llm(self):
return self._create_llm(model=self.model, streaming=False)
@@ -74,17 +75,17 @@ class QABaseBrainPicking(BaseBrainPicking):
@property
def doc_llm(self):
return self._create_llm(
model=self.model, streaming=self.streaming, callbacks=self.callbacks
model=self.model, streaming=True, callbacks=self.callbacks
)
@property
def question_generator(self) -> LLMChain:
return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT)
return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True)
@property
def doc_chain(self) -> LLMChain:
return load_qa_chain(
llm=self.doc_llm, chain_type="stuff"
llm=self.doc_llm, chain_type="stuff", verbose=True
) # pyright: ignore reportPrivateUsage=none
@property
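These properties are the building blocks of the retrieval chain: a non-streaming LLM condenses the follow-up question, a streaming LLM answers over the retrieved documents, and the vector store supplies the retriever. Roughly how they fit together (the vector store is only assumed to exist here, and LangChain's stock CONDENSE_QUESTION_PROMPT stands in for the repository's own prompt):

```python
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI

question_generator = LLMChain(
    llm=ChatOpenAI(temperature=0),   # non-streaming: rewrites the question
    prompt=CONDENSE_QUESTION_PROMPT,
    verbose=True,
)
doc_chain = load_qa_chain(
    llm=ChatOpenAI(streaming=True),  # streaming: produces the answer tokens
    chain_type="stuff",
    verbose=True,
)
qa = ConversationalRetrievalChain(
    retriever=vector_store.as_retriever(),  # assumed: any LangChain vector store
    combine_docs_chain=doc_chain,
    question_generator=question_generator,
)
```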
@@ -170,10 +171,20 @@ class QABaseBrainPicking(BaseBrainPicking):
:param question: The question
:return: An async iterable which generates the answer.
"""
history = get_chat_history(self.chat_id)
callback = self.callbacks[0]
callback = AsyncIteratorCallbackHandler()
self.callbacks = [callback]
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
)
llm = ChatOpenAI(temperature=0)
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
doc_chain = load_qa_chain(model, chain_type="stuff")
qa = ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(), combine_docs_chain=doc_chain, question_generator=question_generator)
transformed_history = []
# Format the chat history into a list of tuples (human, ai)
@@ -183,6 +194,7 @@ class QABaseBrainPicking(BaseBrainPicking):
response_tokens = []
# Wrap an awaitable with an event to signal when it's done or an exception is raised.
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
@@ -190,16 +202,13 @@ class QABaseBrainPicking(BaseBrainPicking):
logger.error(f"Caught exception: {e}")
finally:
event.set()
task = asyncio.create_task(
wrap_done(
self.qa._acall_chain( # pyright: ignore reportPrivateUsage=none
self.qa, question, transformed_history
),
callback.done, # pyright: ignore reportPrivateUsage=none
)
)
# Begin a task that runs in the background.
run = asyncio.create_task(wrap_done(
qa.acall({"question": question, "chat_history": transformed_history}),
callback.done,
))
streamed_chat_history = update_chat_history(
chat_id=self.chat_id,
user_message=question,
@@ -216,8 +225,7 @@ class QABaseBrainPicking(BaseBrainPicking):
yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
await task
await run
# Join the tokens to create the assistant's response
assistant = "".join(response_tokens)
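This is the heart of the feature: the chain runs as a background task while the generator drains tokens from the callback and yields each partial answer as a server-sent event. A condensed sketch of the flow, with `qa` and `callback` as built above; the real method also persists every update through `update_chat_history` / `update_message_by_id`, which is omitted here:

```python
import asyncio
import json
from typing import AsyncIterable, Awaitable

async def generate_stream(question: str, history: list) -> AsyncIterable[str]:
    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        # Signal completion (or failure) so the token iterator can finish.
        try:
            await fn
        except Exception as e:
            print(f"Caught exception: {e}")
        finally:
            event.set()

    # Kick off the chain; its streaming LLM feeds tokens into `callback`.
    run = asyncio.create_task(
        wrap_done(
            qa.acall({"question": question, "chat_history": history}),
            callback.done,
        )
    )

    tokens = []
    async for token in callback.aiter():
        tokens.append(token)
        # Each partial answer goes out as an SSE "data:" line
        # (the real payload is the serialized chat history row).
        yield f"data: {json.dumps({'assistant': ''.join(tokens)})}\n\n"

    await run
```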

View File

@@ -3,6 +3,7 @@ import time
from http.client import HTTPException
from typing import List
from uuid import UUID
from venv import logger
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Query, Request
@@ -18,9 +19,6 @@ from repository.chat.get_chat_by_id import get_chat_by_id
from repository.chat.get_chat_history import get_chat_history
from repository.chat.get_user_chats import get_user_chats
from repository.chat.update_chat import ChatUpdatableProperties, update_chat
from utils.constants import (
streaming_compatible_models,
)
chat_router = APIRouter()
@@ -228,22 +226,14 @@ async def create_stream_question_handler(
current_user: User = Depends(get_current_user),
) -> StreamingResponse:
# TODO: check if the user has access to the brain
if not brain_id:
brain_id = get_default_user_brain_or_create_new(current_user).id
if chat_question.model not in streaming_compatible_models:
# Forward the request to the none streaming endpoint
return await create_question_handler(
request,
chat_question,
chat_id,
current_user, # pyright: ignore reportPrivateUsage=none
)
try:
user_openai_api_key = request.headers.get("Openai-Api-Key")
streaming = True
logger.info(f"Streaming request for {chat_question.model}")
check_user_limit(current_user)
if not brain_id:
brain_id = get_default_user_brain_or_create_new(current_user).id
gpt_answer_generator = OpenAIBrainPicking(
chat_id=str(chat_id),
@@ -251,10 +241,11 @@ async def create_stream_question_handler(
max_tokens=chat_question.max_tokens,
temperature=chat_question.temperature,
brain_id=str(brain_id),
user_openai_api_key=user_openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=streaming,
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=True,
)
print("streaming")
return StreamingResponse(
gpt_answer_generator.generate_stream( # pyright: ignore reportPrivateUsage=none
chat_question.question
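On the FastAPI side the async generator is handed straight to `StreamingResponse`. A trimmed illustration; the real `ChatQuestion` also carries model, temperature and max_tokens, and the media type is assumed to be `text/event-stream` to match the `data:` framing above:

```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class ChatQuestion(BaseModel):
    question: str  # trimmed; the real model has more fields

@app.post("/chat/{chat_id}/question/stream")
async def stream_question(chat_id: str, chat_question: ChatQuestion) -> StreamingResponse:
    # answer_generator is assumed to expose generate_stream(), as in the diff.
    return StreamingResponse(
        answer_generator.generate_stream(chat_question.question),
        media_type="text/event-stream",
    )
```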

View File

@@ -24,11 +24,12 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
def similarity_search(
self,
query: str,
table: str = "match_vectors",
k: int = 6,
table: str = "match_vectors",
threshold: float = 0.5,
**kwargs: Any
) -> List[Document]:
vectors = self._embedding.embed_documents([query])
query_embedding = vectors[0]
res = self._client.rpc(
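The change here is only the order of `k` and `table` in the signature; as long as callers pass these by keyword, behaviour is identical. Hypothetical usage:

```python
docs = vector_store.similarity_search(
    "What did I upload about onboarding?",  # illustrative query
    k=6,
    table="match_vectors",
    threshold=0.5,
)
```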

View File

@@ -9,8 +9,10 @@ import { useChatContext } from "@/lib/context/ChatProvider/hooks/useChatContext"
import { useToast } from "@/lib/hooks";
import { useEventTracking } from "@/services/analytics/useEventTracking";
import { useQuestion } from "./useQuestion";
import { ChatQuestion } from "../types";
import { useQuestion } from "./useQuestion";
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useChat = () => {
@@ -68,11 +70,9 @@ export const useChat = () => {
void track("QUESTION_ASKED");
if (chatQuestion.model === "gpt-3.5-turbo") {
await addStreamQuestion(currentChatId, chatQuestion);
} else {
await addQuestionToModel(currentChatId, chatQuestion);
}
await addStreamQuestion(currentChatId, chatQuestion);
callback?.();
} catch (error) {

View File

@@ -79,7 +79,7 @@ export const useQuestion = (): UseChatService => {
Accept: "text/event-stream",
};
const body = JSON.stringify(chatQuestion);
console.log("Calling API...");
try {
const response = await fetchInstance.post(
`/chat/${chatId}/question/stream?brain_id=${currentBrain.id}`,
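The hook posts to `/chat/${chatId}/question/stream?brain_id=...` and reads the response body incrementally. For completeness, the same stream consumed from a hypothetical Python client (httpx is an assumption, and the JSON body is trimmed to the question; the real request sends the full chatQuestion):

```python
import httpx

def ask_stream(base_url: str, chat_id: str, brain_id: str, question: str, token: str) -> None:
    url = f"{base_url}/chat/{chat_id}/question/stream?brain_id={brain_id}"
    headers = {"Authorization": f"Bearer {token}", "Accept": "text/event-stream"}
    with httpx.stream("POST", url, headers=headers, json={"question": question}) as response:
        for chunk in response.iter_text():
            # Each chunk is a `data: {...}` event carrying the partial answer.
            print(chunk, end="", flush=True)
```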

View File

@@ -21,6 +21,7 @@ export type BrainConfigContextType = {
// export const openAiModels = ["gpt-3.5-turbo", "gpt-4"] as const; ## TODO activate GPT4 when not in demo mode
export const openAiModels = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k",
] as const;