feat: stream router (#353)

* wip: stream router

* feat: chatai streaming

* chore: add comments

* feat: streaming for chains

* chore: comments
Matt authored 2023-06-20 20:53:04 +01:00, committed by GitHub
parent 90bd49527b
commit 3e753f2d56
2 changed files with 128 additions and 1 deletion

main.py

@@ -10,6 +10,7 @@ from routes.chat_routes import chat_router
from routes.crawl_routes import crawl_router
from routes.explore_routes import explore_router
from routes.misc_routes import misc_router
+from routes.stream_routes import stream_router
from routes.upload_routes import upload_router
from routes.user_routes import user_router
@@ -19,12 +20,14 @@ app = FastAPI()
add_cors_middleware(app)
max_brain_size = os.getenv("MAX_BRAIN_SIZE")
-max_brain_size_with_own_key = os.getenv("MAX_BRAIN_SIZE_WITH_KEY",209715200)
+max_brain_size_with_own_key = os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200)
@app.on_event("startup")
async def startup_event():
    pypandoc.download_pandoc()
app.include_router(brain_router)
app.include_router(chat_router)
app.include_router(crawl_router)
@@ -33,3 +36,4 @@ app.include_router(misc_router)
app.include_router(upload_router)
app.include_router(user_router)
app.include_router(api_key_router)
+app.include_router(stream_router)

routes/stream_routes.py

@@ -0,0 +1,123 @@
import asyncio
import os
from typing import AsyncIterable, Awaitable

from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
from logger import get_logger
from models.chats import ChatMessage
from models.settings import CommonsDep, common_dependencies
from models.users import User
from supabase import create_client
from utils.users import fetch_user_id_from_credentials
from vectorstore.supabase import CustomSupabaseVectorStore

logger = get_logger(__name__)

stream_router = APIRouter()

openai_api_key = os.getenv("OPENAI_API_KEY")
supabase_url = os.getenv("SUPABASE_URL")
supabase_service_key = os.getenv("SUPABASE_SERVICE_KEY")


async def send_message(
    chat_message: ChatMessage, chain, callback
) -> AsyncIterable[str]:
    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        """Wrap an awaitable with an event to signal when it's done or an exception is raised."""
        try:
            resp = await fn
            logger.debug("Done: %s", resp)
        except Exception as e:
            logger.error(f"Caught exception: {e}")
        finally:
            # Signal the aiter to stop.
            event.set()

    # Use the agenerate method for models.
    # Use the acall method for chains.
    task = asyncio.create_task(
        wrap_done(
            chain.acall(
                {
                    "question": chat_message.question,
                    "chat_history": chat_message.history,
                }
            ),
            callback.done,
        )
    )

    # Use the aiter method of the callback to stream the response with server-sent events
    async for token in callback.aiter():
        logger.info("Token: %s", token)
        yield f"data: {token}\n\n"

    await task


def create_chain(commons: CommonsDep, current_user: User):
    user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email})
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    supabase_client = create_client(supabase_url, supabase_service_key)
    vector_store = CustomSupabaseVectorStore(
        supabase_client, embeddings, table_name="vectors", user_id=user_id
    )

    generator_llm = ChatOpenAI(
        temperature=0,
    )
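
    # Two LLM instances: generator_llm condenses the question and chat history
    # into a standalone question (an internal step whose tokens should not be
    # sent to the client), while streaming_llm streams the final answer through
    # the callback below.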
    # Callback provides the on_llm_new_token method
    callback = AsyncIteratorCallbackHandler()
    streaming_llm = ChatOpenAI(
        temperature=0,
        streaming=True,
        callbacks=[callback],
    )

    question_generator = LLMChain(
        llm=generator_llm,
        prompt=CONDENSE_QUESTION_PROMPT,
    )
    doc_chain = load_qa_chain(
        llm=streaming_llm,
        chain_type="stuff",
    )

    return (
        ConversationalRetrievalChain(
            combine_docs_chain=doc_chain,
            question_generator=question_generator,
            retriever=vector_store.as_retriever(),
            verbose=True,
        ),
        callback,
    )


@stream_router.post("/stream", dependencies=[Depends(AuthBearer())], tags=["Stream"])
async def stream(
    chat_message: ChatMessage,
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    commons = common_dependencies()
    qa_chain, callback = create_chain(commons, current_user)

    return StreamingResponse(
        send_message(chat_message, qa_chain, callback),
        media_type="text/event-stream",
    )
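
For reference, a minimal sketch of consuming the new endpoint from an async client. This is not part of the commit: httpx is just one possible HTTP client, the host, port, and bearer token are placeholders, and the payload assumes question and history are the relevant ChatMessage fields (they are the two that send_message reads).

import asyncio

import httpx


async def consume_stream():
    # question/history mirror what send_message passes to chain.acall.
    payload = {"question": "What do my documents say about pricing?", "history": []}
    headers = {"Authorization": "Bearer <jwt>"}  # placeholder token

    async with httpx.AsyncClient(timeout=None) as client:
        # Stream the response instead of buffering it, so each
        # "data: <token>\n\n" server-sent-event line arrives as it is yielded.
        async with client.stream(
            "POST", "http://localhost:5050/stream", json=payload, headers=headers
        ) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    print(line[len("data: "):], end="", flush=True)


asyncio.run(consume_stream())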