From 6f047f4a39f504102395ca019dd15dfaa7bd9d50 Mon Sep 17 00:00:00 2001 From: Matt <77928207+mattzcarey@users.noreply.github.com> Date: Fri, 30 Jun 2023 09:10:59 +0100 Subject: [PATCH] feat: streaming for standard brain picking (#385) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: streaming for standard brain picking * fix(bug): private llm * wip: test Co-authored-by: Mamadou DICKO * wip: almost good Co-authored-by: Mamadou DICKO * feat: useFetch * chore: remove 💀 * chore: fix linting * fix: forward the request if not streaming * feat: streaming for standard brain picking * fix(bug): private llm * wip: test Co-authored-by: Mamadou DICKO * wip: almost good Co-authored-by: Mamadou DICKO * feat: useFetch * chore: remove 💀 * chore: fix linting * fix: forward the request if not streaming * fix: 💀 code * fix: check_user_limit * feat: brain_id to new chat stream * fix: missing imports * feat: message_id created on backend Co-authored-by: Mamadou DICKO * chore: remove dead * remove: cpython * remove: dead --------- Co-authored-by: Mamadou DICKO --- backend/llm/PrivateBrainPicking.py | 20 +- backend/llm/brainpicking.py | 189 ++++++++++++------ backend/main.py | 2 - backend/models/chat.py | 5 +- backend/repository/chat/update_chat.py | 7 +- .../repository/chat/update_chat_history.py | 13 +- .../repository/chat/update_message_by_id.py | 38 ++++ backend/routes/chat_routes.py | 112 ++++++++--- backend/routes/stream_routes.py | 121 ----------- backend/utils/common.py | 8 - backend/utils/constants.py | 8 + .../app/chat/[chatId]/context/ChatContext.tsx | 40 ++++ frontend/app/chat/[chatId]/hooks/useChat.ts | 38 +++- .../app/chat/[chatId]/hooks/useChatService.ts | 136 +++++++++++-- frontend/lib/helpers/uuid.ts | 13 ++ frontend/lib/hooks/index.ts | 1 + frontend/lib/hooks/useFetch.ts | 72 +++++++ 17 files changed, 553 insertions(+), 270 deletions(-) create mode 100644 backend/repository/chat/update_message_by_id.py delete mode 100644 backend/routes/stream_routes.py delete mode 100644 backend/utils/common.py create mode 100644 backend/utils/constants.py create mode 100644 frontend/lib/helpers/uuid.ts create mode 100644 frontend/lib/hooks/useFetch.ts diff --git a/backend/llm/PrivateBrainPicking.py b/backend/llm/PrivateBrainPicking.py index c578de5a0..9eb6dadbf 100644 --- a/backend/llm/PrivateBrainPicking.py +++ b/backend/llm/PrivateBrainPicking.py @@ -1,9 +1,9 @@ - # Importing various modules and classes from a custom library 'langchain' likely used for natural language processing from langchain.llms import GPT4All from langchain.llms.base import LLM from llm.brainpicking import BrainPicking from logger import get_logger +from models.settings import LLMSettings logger = get_logger(__name__) @@ -13,6 +13,9 @@ class PrivateBrainPicking(BrainPicking): This subclass of BrainPicking is used to specifically work with a private language model. """ + # Initialize class settings + llm_settings = LLMSettings() + def __init__( self, model: str, @@ -28,7 +31,7 @@ class PrivateBrainPicking(BrainPicking): :param brain_id: The user id to be used for CustomSupabaseVectorStore. 
:return: PrivateBrainPicking instance """ - # Call the parent class's initializer + super().__init__( model=model, brain_id=brain_id, @@ -38,20 +41,17 @@ class PrivateBrainPicking(BrainPicking): user_openai_api_key=user_openai_api_key, ) - def _determine_llm( - self, private_model_args: dict, private: bool = True, model_name: str = None - ) -> LLM: + def _create_llm(self, model_name, streaming=False, callbacks=None) -> LLM: """ - Override the _determine_llm method to enforce the use of a private model. + Override the _create_llm method to enforce the use of a private model. :param model_name: Language model name to be used. :param private_model_args: Dictionary containing model_path, n_ctx and n_batch. :param private: Boolean value to determine if private model is to be used. Defaulted to True. :return: Language model instance """ - # Force the use of a private model by setting private to True. - model_path = private_model_args["model_path"] - model_n_ctx = private_model_args["n_ctx"] - model_n_batch = private_model_args["n_batch"] + model_path = self.llm_settings.model_path + model_n_ctx = self.llm_settings.model_n_ctx + model_n_batch = self.llm_settings.model_n_batch logger.info("Using private model: %s", model_path) diff --git a/backend/llm/brainpicking.py b/backend/llm/brainpicking.py index da4a4bfb4..3ccbdff2e 100644 --- a/backend/llm/brainpicking.py +++ b/backend/llm/brainpicking.py @@ -1,4 +1,8 @@ -from typing import Any, Dict +import asyncio +import json +from typing import AsyncIterable, Awaitable + +from langchain.callbacks import AsyncIteratorCallbackHandler # Importing various modules and classes from a custom library 'langchain' likely used for natural language processing from langchain.chains import ConversationalRetrievalChain, LLMChain @@ -6,69 +10,53 @@ from langchain.chains.question_answering import load_qa_chain from langchain.chat_models import ChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings from langchain.llms.base import LLM -from langchain.memory import ConversationBufferMemory from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT from logger import get_logger -from models.settings import \ - BrainSettings # Importing settings related to the 'brain' -from models.settings import LLMSettings # For type hinting +from models.settings import BrainSettings # Importing settings related to the 'brain' from pydantic import BaseModel # For data validation and settings management from repository.chat.get_chat_history import get_chat_history -from vectorstore.supabase import \ - CustomSupabaseVectorStore # Custom class for handling vector storage with Supabase - +from repository.chat.update_chat_history import update_chat_history +from repository.chat.update_message_by_id import update_message_by_id from supabase import Client # For interacting with Supabase database from supabase import create_client +from vectorstore.supabase import ( + CustomSupabaseVectorStore, +) # Custom class for handling vector storage with Supabase logger = get_logger(__name__) -class AnswerConversationBufferMemory(ConversationBufferMemory): - """ - This class is a specialized version of ConversationBufferMemory. - It overrides the save_context method to save the response using the 'answer' key in the outputs. - Reference to some issue comment is given in the docstring. 
- """ - - def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: - # Overriding the save_context method of the parent class - return super(AnswerConversationBufferMemory, self).save_context( - inputs, {"response": outputs["answer"]} - ) - - -def format_chat_history(inputs) -> str: - """ - Function to concatenate chat history into a single string. - :param inputs: List of tuples containing human and AI messages. - :return: concatenated string of chat history - """ - res = [] - for human, ai in inputs: - res.append(f"{human}:{ai}\n") - return "\n".join(res) - - class BrainPicking(BaseModel): """ Main class for the Brain Picking functionality. It allows to initialize a Chat model, generate questions and retrieve answers using ConversationalRetrievalChain. """ + # Instantiate settings + settings = BrainSettings() + # Default class attributes llm_name: str = "gpt-3.5-turbo" temperature: float = 0.0 - settings = BrainSettings() - llm_config = LLMSettings() - embeddings: OpenAIEmbeddings = None - supabase_client: Client = None - vector_store: CustomSupabaseVectorStore = None - llm: LLM = None - question_generator: LLMChain = None - doc_chain: ConversationalRetrievalChain = None chat_id: str max_tokens: int = 256 + # Storage + supabase_client: Client = None + vector_store: CustomSupabaseVectorStore = None + + # Language models + embeddings: OpenAIEmbeddings = None + question_llm: LLM = None + doc_llm: LLM = None + question_generator: LLMChain = None + doc_chain: LLMChain = None + qa: ConversationalRetrievalChain = None + + # Streaming + callback: AsyncIteratorCallbackHandler = None + streaming: bool = False + class Config: # Allowing arbitrary types for class validation arbitrary_types_allowed = True @@ -81,6 +69,7 @@ class BrainPicking(BaseModel): chat_id: str, max_tokens: int, user_openai_api_key: str, + streaming: bool = False, ) -> "BrainPicking": """ Initialize the BrainPicking class by setting embeddings, supabase client, vector store, language model and chains. @@ -113,25 +102,38 @@ class BrainPicking(BaseModel): brain_id=brain_id, ) - self.llm = self._determine_llm( - private_model_args={ - "model_path": self.llm_config.model_path, - "n_ctx": self.llm_config.model_n_ctx, - "n_batch": self.llm_config.model_n_batch, - }, - private=self.llm_config.private, + self.question_llm = self._create_llm( model_name=self.llm_name, + streaming=False, ) self.question_generator = LLMChain( - llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT + llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT ) - self.doc_chain = load_qa_chain(self.llm, chain_type="stuff") + + if streaming: + self.callback = AsyncIteratorCallbackHandler() + self.doc_llm = self._create_llm( + model_name=self.llm_name, + streaming=streaming, + callbacks=[self.callback], + ) + self.doc_chain = load_qa_chain( + llm=self.doc_llm, + chain_type="stuff", + ) + self.streaming = streaming + else: + self.doc_llm = self._create_llm( + model_name=self.llm_name, + streaming=streaming, + ) + self.doc_chain = load_qa_chain(llm=self.doc_llm, chain_type="stuff") + self.streaming = streaming + self.chat_id = chat_id self.max_tokens = max_tokens - def _determine_llm( - self, private_model_args: dict, private: bool = False, model_name: str = None - ) -> LLM: + def _create_llm(self, model_name, streaming=False, callbacks=None) -> LLM: """ Determine the language model to be used. :param model_name: Language model name to be used. 
@@ -139,8 +141,12 @@ class BrainPicking(BaseModel): :param private: Boolean value to determine if private model is to be used. :return: Language model instance """ - - return ChatOpenAI(temperature=0, model_name=model_name) + return ChatOpenAI( + temperature=0, + model_name=model_name, + streaming=streaming, + callbacks=callbacks, + ) def _get_qa( self, @@ -155,11 +161,11 @@ class BrainPicking(BaseModel): # Initialize and return a ConversationalRetrievalChain qa = ConversationalRetrievalChain( retriever=self.vector_store.as_retriever(), - max_tokens_limit=self.max_tokens, question_generator=self.question_generator, combine_docs_chain=self.doc_chain, - get_chat_history=format_chat_history, + verbose=True, ) + return qa def generate_answer(self, question: str) -> str: @@ -182,3 +188,70 @@ class BrainPicking(BaseModel): answer = model_response["answer"] return answer + + async def generate_stream(self, question: str) -> AsyncIterable: + """ + Generate a streaming answer to a given question by interacting with the language model. + :param question: The question + :return: An async iterable which generates the answer. + """ + + # Get the QA chain + qa = self._get_qa() + history = get_chat_history(self.chat_id) + callback = self.callback + + # # Format the chat history into a list of tuples (human, ai) + transformed_history = [(chat.user_message, chat.assistant) for chat in history] + + # Initialize a list to hold the tokens + response_tokens = [] + + # Wrap an awaitable with a event to signal when it's done or an exception is raised. + async def wrap_done(fn: Awaitable, event: asyncio.Event): + try: + await fn + except Exception as e: + logger.error(f"Caught exception: {e}") + finally: + event.set() + + # Use the acall method to perform an async call to the QA chain + task = asyncio.create_task( + wrap_done( + qa.acall( + { + "question": question, + "chat_history": transformed_history, + } + ), + callback.done, + ) + ) + + streamed_chat_history = update_chat_history( + chat_id=self.chat_id, + user_message=question, + assistant="", + ) + + # Use the aiter method of the callback to stream the response with server-sent-events + async for token in callback.aiter(): + logger.info("Token: %s", token) + + # Add the token to the response_tokens list + response_tokens.append(token) + streamed_chat_history.assistant = token + + yield f"data: {json.dumps(streamed_chat_history.to_dict())}" + + await task + + # Join the tokens to create the assistant's response + assistant = "".join(response_tokens) + + update_message_by_id( + message_id=streamed_chat_history.message_id, + user_message=question, + assistant=assistant, + ) diff --git a/backend/main.py b/backend/main.py index ad890dfc0..791f43932 100644 --- a/backend/main.py +++ b/backend/main.py @@ -10,7 +10,6 @@ from routes.chat_routes import chat_router from routes.crawl_routes import crawl_router from routes.explore_routes import explore_router from routes.misc_routes import misc_router -from routes.stream_routes import stream_router from routes.upload_routes import upload_router from routes.user_routes import user_router @@ -35,7 +34,6 @@ app.include_router(misc_router) app.include_router(upload_router) app.include_router(user_router) app.include_router(api_key_router) -app.include_router(stream_router) @app.exception_handler(HTTPException) diff --git a/backend/models/chat.py b/backend/models/chat.py index 9d8da514b..266945e92 100644 --- a/backend/models/chat.py +++ b/backend/models/chat.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from 
dataclasses import asdict, dataclass @dataclass @@ -29,3 +29,6 @@ class ChatHistory: self.user_message = chat_dict.get("user_message") self.assistant = chat_dict.get("assistant") self.message_time = chat_dict.get("message_time") + + def to_dict(self): + return asdict(self) diff --git a/backend/repository/chat/update_chat.py b/backend/repository/chat/update_chat.py index fea0cf55b..e866333d9 100644 --- a/backend/repository/chat/update_chat.py +++ b/backend/repository/chat/update_chat.py @@ -1,11 +1,10 @@ +from dataclasses import dataclass +from typing import Optional + from logger import get_logger from models.chat import Chat -from typing import Optional -from dataclasses import dataclass - from models.settings import common_dependencies - logger = get_logger(__name__) diff --git a/backend/repository/chat/update_chat_history.py b/backend/repository/chat/update_chat_history.py index 0cfd1d86c..5680bc53c 100644 --- a/backend/repository/chat/update_chat_history.py +++ b/backend/repository/chat/update_chat_history.py @@ -1,12 +1,11 @@ +from typing import List # For type hinting + +from fastapi import HTTPException from models.chat import ChatHistory from models.settings import common_dependencies -from typing import List # For type hinting -from fastapi import HTTPException -def update_chat_history( - chat_id: str, user_message: str, assistant_answer: str -) -> ChatHistory: +def update_chat_history(chat_id: str, user_message: str, assistant: str) -> ChatHistory: commons = common_dependencies() response: List[ChatHistory] = ( commons["supabase"] @@ -15,7 +14,7 @@ def update_chat_history( { "chat_id": str(chat_id), "user_message": user_message, - "assistant": assistant_answer, + "assistant": assistant, } ) .execute() @@ -24,4 +23,4 @@ def update_chat_history( raise HTTPException( status_code=500, detail="An exception occurred while updating chat history." 
) - return response[0] + return ChatHistory(response[0]) diff --git a/backend/repository/chat/update_message_by_id.py b/backend/repository/chat/update_message_by_id.py new file mode 100644 index 000000000..3265d86fd --- /dev/null +++ b/backend/repository/chat/update_message_by_id.py @@ -0,0 +1,38 @@ +from logger import get_logger +from models.chat import ChatHistory +from models.settings import common_dependencies + +logger = get_logger(__name__) + + +def update_message_by_id( + message_id: str, user_message: str, assistant: str +) -> ChatHistory: + commons = common_dependencies() + + if not message_id: + logger.error("No message_id provided") + return + + updates = {} + + if user_message is not None: + updates["user_message"] = user_message + + if assistant is not None: + updates["assistant"] = user_message + + updated_message = None + + if updates: + updated_message = ( + commons["supabase"] + .table("chat_history") + .update(updates) + .match({"message_id": message_id}) + .execute() + ).data[0] + logger.info(f"Message {message_id} updated") + else: + logger.info(f"No updates to apply for message {message_id}") + return ChatHistory(updated_message) diff --git a/backend/routes/chat_routes.py b/backend/routes/chat_routes.py index 9506de56b..381aee3b2 100644 --- a/backend/routes/chat_routes.py +++ b/backend/routes/chat_routes.py @@ -5,7 +5,8 @@ from typing import List from uuid import UUID from auth.auth_bearer import AuthBearer, get_current_user -from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi import APIRouter, Depends, Query, Request +from fastapi.responses import StreamingResponse from llm.brainpicking import BrainPicking from llm.BrainPickingOpenAIFunctions.BrainPickingOpenAIFunctions import ( BrainPickingOpenAIFunctions, @@ -21,6 +22,10 @@ from repository.chat.get_chat_history import get_chat_history from repository.chat.get_user_chats import get_user_chats from repository.chat.update_chat import ChatUpdatableProperties, update_chat from repository.chat.update_chat_history import update_chat_history +from utils.constants import ( + openai_function_compatible_models, + streaming_compatible_models, +) chat_router = APIRouter() @@ -40,6 +45,36 @@ def delete_chat_from_db(commons, chat_id): commons["supabase"].table("chats").delete().match({"chat_id": chat_id}).execute() +def fetch_user_stats(commons, user, date): + response = ( + commons["supabase"] + .from_("users") + .select("*") + .filter("email", "eq", user.email) + .filter("date", "eq", date) + .execute() + ) + userItem = next(iter(response.data or []), {"requests_count": 0}) + return userItem + + +def check_user_limit( + user: User, +): + if user.user_openai_api_key is None: + date = time.strftime("%Y%m%d") + max_requests_number = os.getenv("MAX_REQUESTS_NUMBER") + + user.increment_user_request_count(date) + if user.requests_count >= float(max_requests_number): + raise HTTPException( + status_code=429, + detail="You have reached the maximum number of requests for today.", + ) + else: + pass + + # get all chats @chat_router.get("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"]) async def get_chats(current_user: User = Depends(get_current_user)): @@ -52,7 +87,6 @@ async def get_chats(current_user: User = Depends(get_current_user)): This endpoint retrieves all the chats associated with the current authenticated user. It returns a list of chat objects containing the chat ID and chat name for each chat. 
""" - commons = common_dependencies() chats = get_user_chats(current_user.id) return {"chats": chats} @@ -82,7 +116,6 @@ async def update_chat_metadata_handler( """ Update chat attributes """ - commons = common_dependencies() chat = get_chat_by_id(chat_id) if current_user.id != chat.user_id: @@ -92,24 +125,6 @@ async def update_chat_metadata_handler( return update_chat(chat_id=chat_id, chat_data=chat_data) -# helper method for update and create chat -def check_user_limit( - user: User, -): - if user.user_openai_api_key is None: - date = time.strftime("%Y%m%d") - max_requests_number = os.getenv("MAX_REQUESTS_NUMBER") - - user.increment_user_request_count(date) - if user.requests_count >= float(max_requests_number): - raise HTTPException( - status_code=429, - detail="You have reached the maximum number of requests for today.", - ) - else: - pass - - # create new chat @chat_router.post("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"]) async def create_chat_handler( @@ -139,10 +154,7 @@ async def create_question_handler( try: check_user_limit(current_user) llm_settings = LLMSettings() - openai_function_compatible_models = [ - "gpt-3.5-turbo-0613", - "gpt-4-0613", - ] + if llm_settings.private: gpt_answer_generator = PrivateBrainPicking( model=chat_question.model, @@ -153,6 +165,7 @@ async def create_question_handler( user_openai_api_key=current_user.user_openai_api_key, ) answer = gpt_answer_generator.generate_answer(chat_question.question) + elif chat_question.model in openai_function_compatible_models: # TODO: RBAC with current_user gpt_answer_generator = BrainPickingOpenAIFunctions( @@ -165,6 +178,7 @@ async def create_question_handler( user_openai_api_key=current_user.user_openai_api_key, ) answer = gpt_answer_generator.generate_answer(chat_question.question) + else: brainPicking = BrainPicking( chat_id=str(chat_id), @@ -174,18 +188,64 @@ async def create_question_handler( brain_id=brain_id, user_openai_api_key=current_user.user_openai_api_key, ) + answer = brainPicking.generate_answer(chat_question.question) chat_answer = update_chat_history( chat_id=chat_id, user_message=chat_question.question, - assistant_answer=answer, + assistant=answer, ) return chat_answer except HTTPException as e: raise e +# stream new question response from chat +@chat_router.post( + "/chat/{chat_id}/question/stream", + dependencies=[Depends(AuthBearer())], + tags=["Chat"], +) +async def create_stream_question_handler( + request: Request, + chat_question: ChatQuestion, + chat_id: UUID, + brain_id: UUID = Query(..., description="The ID of the brain"), + current_user: User = Depends(get_current_user), +) -> StreamingResponse: + if ( + os.getenv("PRIVATE") == "True" + or chat_question.model not in streaming_compatible_models + ): + # forward the request to the none streaming endpoint create_question_handler function + return await create_question_handler( + request, chat_question, chat_id, current_user + ) + + try: + user_openai_api_key = request.headers.get("Openai-Api-Key") + check_user_limit(current_user) + + brain = BrainPicking( + chat_id=str(chat_id), + model=chat_question.model, + max_tokens=chat_question.max_tokens, + temperature=chat_question.temperature, + brain_id=brain_id, + user_openai_api_key=user_openai_api_key, + streaming=True, + ) + + return StreamingResponse( + brain.generate_stream(chat_question.question), + media_type="text/event-stream", + ) + + except HTTPException as e: + raise e + + # get chat history @chat_router.get( "/chat/{chat_id}/history", 
dependencies=[Depends(AuthBearer())], tags=["Chat"] diff --git a/backend/routes/stream_routes.py b/backend/routes/stream_routes.py deleted file mode 100644 index c87f2dab4..000000000 --- a/backend/routes/stream_routes.py +++ /dev/null @@ -1,121 +0,0 @@ -import asyncio -import os -from typing import AsyncIterable, Awaitable -from uuid import UUID - -from auth.auth_bearer import AuthBearer -from fastapi import APIRouter, Depends, Query -from fastapi.responses import StreamingResponse -from langchain.callbacks import AsyncIteratorCallbackHandler -from langchain.chains import ConversationalRetrievalChain -from langchain.chains.llm import LLMChain -from langchain.chains.question_answering import load_qa_chain -from langchain.chat_models import ChatOpenAI -from langchain.embeddings import OpenAIEmbeddings -from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT -from logger import get_logger -from models.chats import ChatMessage -from models.settings import CommonsDep, common_dependencies -from vectorstore.supabase import CustomSupabaseVectorStore - -from supabase import create_client - -logger = get_logger(__name__) - -stream_router = APIRouter() - -openai_api_key = os.getenv("OPENAI_API_KEY") -supabase_url = os.getenv("SUPABASE_URL") -supabase_service_key = os.getenv("SUPABASE_SERVICE_KEY") - - -async def send_message( - chat_message: ChatMessage, chain, callback -) -> AsyncIterable[str]: - async def wrap_done(fn: Awaitable, event: asyncio.Event): - """Wrap an awaitable with a event to signal when it's done or an exception is raised.""" - try: - resp = await fn - logger.debug("Done: %s", resp) - except Exception as e: - logger.error(f"Caught exception: {e}") - finally: - # Signal the aiter to stop. - event.set() - - # Use the agenerate method for models. - # Use the acall method for chains. 
- task = asyncio.create_task( - wrap_done( - chain.acall( - { - "question": chat_message.question, - "chat_history": chat_message.history, - } - ), - callback.done, - ) - ) - - # Use the aiter method of the callback to stream the response with server-sent-events - async for token in callback.aiter(): - logger.info("Token: %s", token) - yield f"data: {token}\n\n" - - await task - - -def create_chain(commons: CommonsDep, brain_id: UUID): - embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key) - - supabase_client = create_client(supabase_url, supabase_service_key) - - vector_store = CustomSupabaseVectorStore( - supabase_client, embeddings, table_name="vectors", brain_id=brain_id - ) - - generator_llm = ChatOpenAI( - temperature=0, - ) - - # Callback provides the on_llm_new_token method - callback = AsyncIteratorCallbackHandler() - - streaming_llm = ChatOpenAI( - temperature=0, - streaming=True, - callbacks=[callback], - ) - question_generator = LLMChain( - llm=generator_llm, - prompt=CONDENSE_QUESTION_PROMPT, - ) - doc_chain = load_qa_chain( - llm=streaming_llm, - chain_type="stuff", - ) - - return ( - ConversationalRetrievalChain( - combine_docs_chain=doc_chain, - question_generator=question_generator, - retriever=vector_store.as_retriever(), - verbose=True, - ), - callback, - ) - - -@stream_router.post("/stream", dependencies=[Depends(AuthBearer())], tags=["Stream"]) -async def stream( - chat_message: ChatMessage, - brain_id: UUID = Query(..., description="The ID of the brain"), -) -> StreamingResponse: - commons = common_dependencies() - - qa_chain, callback = create_chain(commons, brain_id) - - return StreamingResponse( - send_message(chat_message, qa_chain, callback), - media_type="text/event-stream", - ) diff --git a/backend/utils/common.py b/backend/utils/common.py deleted file mode 100644 index 53fa9e17c..000000000 --- a/backend/utils/common.py +++ /dev/null @@ -1,8 +0,0 @@ -import os -from typing import Annotated - -from fastapi import Depends -from logger import get_logger -from models.settings import common_dependencies - -logger = get_logger(__name__) diff --git a/backend/utils/constants.py b/backend/utils/constants.py new file mode 100644 index 000000000..1e5744c9b --- /dev/null +++ b/backend/utils/constants.py @@ -0,0 +1,8 @@ +openai_function_compatible_models = [ + "gpt-3.5-turbo-0613", + "gpt-4-0613", +] + +streaming_compatible_models = ["gpt-3.5-turbo"] + +private_models = ["gpt4all-j-1.3"] diff --git a/frontend/app/chat/[chatId]/context/ChatContext.tsx b/frontend/app/chat/[chatId]/context/ChatContext.tsx index 4eb24a57b..8d69dceb9 100644 --- a/frontend/app/chat/[chatId]/context/ChatContext.tsx +++ b/frontend/app/chat/[chatId]/context/ChatContext.tsx @@ -8,6 +8,8 @@ type ChatContextProps = { history: ChatHistory[]; setHistory: (history: ChatHistory[]) => void; addToHistory: (message: ChatHistory) => void; + updateHistory: (chat: ChatHistory) => void; + updateStreamingHistory: (streamedChat: ChatHistory) => void; }; export const ChatContext = createContext( @@ -20,16 +22,54 @@ export const ChatProvider = ({ children: JSX.Element | JSX.Element[]; }): JSX.Element => { const [history, setHistory] = useState([]); + const addToHistory = (message: ChatHistory) => { setHistory((prevHistory) => [...prevHistory, message]); }; + const updateStreamingHistory = (streamedChat: ChatHistory): void => { + setHistory((prevHistory: ChatHistory[]) => { + console.log("new chat", streamedChat); + const updatedHistory = prevHistory.find( + (item) => item.message_id === 
streamedChat.message_id + ) + ? prevHistory.map((item: ChatHistory) => + item.message_id === streamedChat.message_id + ? { ...item, assistant: item.assistant + streamedChat.assistant } + : item + ) + : [...prevHistory, streamedChat]; + + console.log("updated history", updatedHistory); + + return updatedHistory; + }); + }; + + const updateHistory = (chat: ChatHistory): void => { + setHistory((prevHistory: ChatHistory[]) => { + const updatedHistory = prevHistory.find( + (item) => item.message_id === chat.message_id + ) + ? prevHistory.map((item: ChatHistory) => + item.message_id === chat.message_id + ? { ...item, assistant: chat.assistant } + : item + ) + : [...prevHistory, chat]; + + return updatedHistory; + }); + }; + return ( {children} diff --git a/frontend/app/chat/[chatId]/hooks/useChat.ts b/frontend/app/chat/[chatId]/hooks/useChat.ts index d565f4f22..1229e7888 100644 --- a/frontend/app/chat/[chatId]/hooks/useChat.ts +++ b/frontend/app/chat/[chatId]/hooks/useChat.ts @@ -1,3 +1,4 @@ +/* eslint-disable max-lines */ import { AxiosError } from "axios"; import { useParams } from "next/navigation"; import { useEffect, useState } from "react"; @@ -21,30 +22,34 @@ export const useChat = () => { const { config: { maxTokens, model, temperature }, } = useBrainConfig(); - const { history, setHistory, addToHistory } = useChatContext(); + const { history, setHistory } = useChatContext(); const { publish } = useToast(); const { createChat, getChatHistory, - addQuestion: addQuestionToChat, + addStreamQuestion, + addQuestion: addQuestionToModel, } = useChatService(); useEffect(() => { const fetchHistory = async () => { - const chatHistory = await getChatHistory(chatId); - setHistory(chatHistory); + const currentChatId = chatId; + const chatHistory = await getChatHistory(currentChatId); + + if (chatId === currentChatId && chatHistory.length > 0) { + setHistory(chatHistory); + } }; void fetchHistory(); - }, [chatId]); + }, [chatId, getChatHistory, setHistory]); const generateNewChatIdFromName = async ( chatName: string ): Promise => { - const rep = await createChat({ name: chatName }); - setChatId(rep.data.chat_id); + const chat = await createChat({ name: chatName }); - return rep.data.chat_id; + return chat.chat_id; }; const addQuestion = async (question: string, callback?: () => void) => { @@ -64,8 +69,15 @@ export const useChat = () => { (await generateNewChatIdFromName( question.split(" ").slice(0, 3).join(" ") )); - const answer = await addQuestionToChat(currentChatId, chatQuestion); - addToHistory(answer); + + setChatId(currentChatId); + + if (chatQuestion.model === "gpt-3.5-turbo") { + await addStreamQuestion(currentChatId, chatQuestion); + } else { + await addQuestionToModel(currentChatId, chatQuestion); + } + callback?.(); } catch (error) { console.error({ error }); @@ -88,5 +100,9 @@ export const useChat = () => { } }; - return { history, addQuestion, generatingAnswer }; + return { + history, + addQuestion, + generatingAnswer, + }; }; diff --git a/frontend/app/chat/[chatId]/hooks/useChatService.ts b/frontend/app/chat/[chatId]/hooks/useChatService.ts index cbfd3dfa9..472e7b021 100644 --- a/frontend/app/chat/[chatId]/hooks/useChatService.ts +++ b/frontend/app/chat/[chatId]/hooks/useChatService.ts @@ -1,45 +1,137 @@ -import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useAxios } from "@/lib/hooks"; +/* eslint-disable max-lines */ +import { useCallback } from "react"; + +import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; 
+import { useAxios, useFetch } from "@/lib/hooks"; + +import { useChatContext } from "../context/ChatContext"; import { ChatEntity, ChatHistory, ChatQuestion } from "../types"; -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useChatService = () => { +interface UseChatService { + createChat: (name: { name: string }) => Promise; + getChatHistory: (chatId: string | undefined) => Promise; + addQuestion: (chatId: string, chatQuestion: ChatQuestion) => Promise; + addStreamQuestion: ( + chatId: string, + chatQuestion: ChatQuestion + ) => Promise; +} + +export const useChatService = (): UseChatService => { const { axiosInstance } = useAxios(); + const { fetchInstance } = useFetch(); + const { updateHistory, updateStreamingHistory } = useChatContext(); const { currentBrain } = useBrainContext(); - const createChat = async ({ name }: { name: string }) => { - return axiosInstance.post(`/chat`, { name }); + const createChat = async ({ + name, + }: { + name: string; + }): Promise => { + const response = (await axiosInstance.post(`/chat`, { name })) + .data; + + return response; }; - const getChatHistory = async (chatId: string | undefined) => { - if (chatId === undefined) { - return []; - } - const rep = ( - await axiosInstance.get(`/chat/${chatId}/history`) - ).data; + const getChatHistory = useCallback( + async (chatId: string | undefined): Promise => { + if (chatId === undefined) { + return []; + } + const response = ( + await axiosInstance.get(`/chat/${chatId}/history`) + ).data; + + return response; + }, + [axiosInstance] + ); - return rep; - }; const addQuestion = async ( chatId: string, chatQuestion: ChatQuestion - ): Promise => { + ): Promise => { if (currentBrain?.id === undefined) { throw new Error("No current brain"); } - return ( - await axiosInstance.post( - `/chat/${chatId}/question?brain_id=${currentBrain.id}`, - chatQuestion - ) - ).data; + const response = await axiosInstance.post( + `/chat/${chatId}/question?brain_id=${currentBrain.id}`, + chatQuestion + ); + + updateHistory(response.data); + }; + + const handleStream = async ( + reader: ReadableStreamDefaultReader + ): Promise => { + const decoder = new TextDecoder("utf-8"); + + const handleStreamRecursively = async () => { + const { done, value } = await reader.read(); + + if (done) { + return; + } + + const dataStrings = decoder + .decode(value) + .trim() + .split("data: ") + .filter(Boolean); + + dataStrings.forEach((data) => { + try { + const parsedData = JSON.parse(data) as ChatHistory; + updateStreamingHistory(parsedData); + } catch (error) { + console.error("Error parsing data:", error); + } + }); + + await handleStreamRecursively(); + }; + + await handleStreamRecursively(); + }; + + const addStreamQuestion = async ( + chatId: string, + chatQuestion: ChatQuestion + ): Promise => { + if (currentBrain?.id === undefined) { + throw new Error("No current brain"); + } + const headers = { + "Content-Type": "application/json", + Accept: "text/event-stream", + }; + const body = JSON.stringify(chatQuestion); + + try { + const response = await fetchInstance.post( + `/chat/${chatId}/question/stream?brain_id=${currentBrain.id}`, + body, + headers + ); + + if (response.body === null) { + throw new Error("Response body is null"); + } + + console.log("Received response. 
Starting to handle stream..."); + await handleStream(response.body.getReader()); + } catch (error) { + console.error("Error calling the API:", error); + } }; return { createChat, getChatHistory, addQuestion, + addStreamQuestion, }; }; diff --git a/frontend/lib/helpers/uuid.ts b/frontend/lib/helpers/uuid.ts new file mode 100644 index 000000000..b32a6defa --- /dev/null +++ b/frontend/lib/helpers/uuid.ts @@ -0,0 +1,13 @@ +export const generateUUID = (): string => { + const array = new Uint32Array(4); + window.crypto.getRandomValues(array); + let idx = -1; + + return "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx".replace(/[xy]/g, (c) => { + idx++; + const r = (array[idx >> 3] >> ((idx % 8) * 4)) & 15; + const v = c === "x" ? r : (r & 0x3) | 0x8; + + return v.toString(16); + }); +}; diff --git a/frontend/lib/hooks/index.ts b/frontend/lib/hooks/index.ts index 3c03132a1..38109f24e 100644 --- a/frontend/lib/hooks/index.ts +++ b/frontend/lib/hooks/index.ts @@ -1,2 +1,3 @@ export * from "./useAxios"; +export * from "./useFetch"; export * from "./useToast"; diff --git a/frontend/lib/hooks/useFetch.ts b/frontend/lib/hooks/useFetch.ts new file mode 100644 index 000000000..e2cea39d9 --- /dev/null +++ b/frontend/lib/hooks/useFetch.ts @@ -0,0 +1,72 @@ +import { useEffect, useState } from "react"; + +import { useBrainConfig } from "../context/BrainConfigProvider/hooks/useBrainConfig"; +import { useSupabase } from "../context/SupabaseProvider"; + +interface FetchInstance { + get: (url: string, headers?: HeadersInit) => Promise; + post: ( + url: string, + body: BodyInit | null | undefined, + headers?: HeadersInit + ) => Promise; + put: ( + url: string, + body: BodyInit | null | undefined, + headers?: HeadersInit + ) => Promise; + delete: (url: string, headers?: HeadersInit) => Promise; +} + +const fetchInstance: FetchInstance = { + get: async (url, headers) => fetch(url, { method: "GET", headers }), + post: async (url, body, headers) => + fetch(url, { method: "POST", body, headers }), + put: async (url, body, headers) => + fetch(url, { method: "PUT", body, headers }), + delete: async (url, headers) => fetch(url, { method: "DELETE", headers }), +}; + +export const useFetch = (): { fetchInstance: FetchInstance } => { + const { session } = useSupabase(); + const { + config: { backendUrl: configBackendUrl, openAiKey }, + } = useBrainConfig(); + + const [instance, setInstance] = useState(fetchInstance); + + const baseURL = `${process.env.NEXT_PUBLIC_BACKEND_URL ?? ""}`; + const backendUrl = configBackendUrl ?? baseURL; + + useEffect(() => { + setInstance({ + ...fetchInstance, + get: async (url, headers) => + fetchInstance.get(`${backendUrl}${url}`, { + Authorization: `Bearer ${session?.access_token ?? ""}`, + "Openai-Api-Key": openAiKey ?? "", + ...headers, + }), + post: async (url, body, headers) => + fetchInstance.post(`${backendUrl}${url}`, body, { + Authorization: `Bearer ${session?.access_token ?? ""}`, + "Openai-Api-Key": openAiKey ?? "", + ...headers, + }), + put: async (url, body, headers) => + fetchInstance.put(`${backendUrl}${url}`, body, { + Authorization: `Bearer ${session?.access_token ?? ""}`, + "Openai-Api-Key": openAiKey ?? "", + ...headers, + }), + delete: async (url, headers) => + fetchInstance.delete(`${backendUrl}${url}`, { + Authorization: `Bearer ${session?.access_token ?? ""}`, + "Openai-Api-Key": openAiKey ?? "", + ...headers, + }), + }); + }, [session, backendUrl, openAiKey]); + + return { fetchInstance: instance }; +};
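
For reference, the core backend pattern this patch introduces in backend/llm/brainpicking.py can be reduced to a small standalone sketch: an AsyncIteratorCallbackHandler is attached to a streaming ChatOpenAI instance, the model call is scheduled as a background task, and tokens are yielded as server-sent events while that task runs. This is a minimal sketch under stated assumptions, not the patch itself: the question, message wrapping, and SSE payload shape here are illustrative, and the chat-history persistence done by the real generate_stream is omitted.

import asyncio
import json
from typing import AsyncIterable, Awaitable

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


async def stream_answer(question: str) -> AsyncIterable[str]:
    # The callback exposes generated tokens as an async iterator.
    callback = AsyncIteratorCallbackHandler()
    llm = ChatOpenAI(temperature=0, streaming=True, callbacks=[callback])

    async def wrap_done(fn: Awaitable, event: asyncio.Event) -> None:
        # Await the model call and signal the token iterator to stop,
        # even if the call raises.
        try:
            await fn
        except Exception as e:
            print(f"Caught exception: {e}")
        finally:
            event.set()

    # Run the model call in the background while tokens are consumed below.
    task = asyncio.create_task(
        wrap_done(llm.agenerate([[HumanMessage(content=question)]]), callback.done)
    )

    # Forward each token as a server-sent event.
    async for token in callback.aiter():
        yield f"data: {json.dumps({'assistant': token})}\n\n"

    await task

A FastAPI route can return this generator directly, e.g. StreamingResponse(stream_answer(question), media_type="text/event-stream"). In the patch, the same callback is attached only to the combine_docs chain of the ConversationalRetrievalChain (driven through qa.acall), while the question-condensing LLM stays non-streaming, so only the final answer tokens reach the client.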
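
On the client side, the new POST /chat/{chat_id}/question/stream endpoint emits "data: "-prefixed JSON events shaped like ChatHistory, each carrying one token in the assistant field. A minimal consumer sketch using requests follows, with placeholder values for the backend URL, ids, token, and payload; the headers mirror what useFetch sends, and, like the frontend's handleStream, it assumes each event arrives within a single chunk.

import json

import requests

BASE_URL = "http://localhost:5050"  # placeholder backend URL
CHAT_ID = "00000000-0000-0000-0000-000000000000"  # placeholder chat id
BRAIN_ID = "11111111-1111-1111-1111-111111111111"  # placeholder brain id

payload = {
    "model": "gpt-3.5-turbo",  # must be in streaming_compatible_models
    "question": "What is in my knowledge base?",
    "temperature": 0.0,
    "max_tokens": 256,
}
headers = {
    "Authorization": "Bearer <access_token>",  # placeholder auth token
    "Openai-Api-Key": "<openai_api_key>",  # optional, as in useFetch
    "Content-Type": "application/json",
    "Accept": "text/event-stream",
}

with requests.post(
    f"{BASE_URL}/chat/{CHAT_ID}/question/stream?brain_id={BRAIN_ID}",
    json=payload,
    headers=headers,
    stream=True,
) as response:
    response.raise_for_status()
    answer_tokens = []
    for chunk in response.iter_content(chunk_size=None):
        # A chunk may contain several "data: {...}" events; split and parse each.
        for data in filter(None, chunk.decode("utf-8").split("data: ")):
            event = json.loads(data)
            answer_tokens.append(event["assistant"])
            print(event["assistant"], end="", flush=True)
    print()
    # The concatenated tokens form the full answer, which the backend also
    # persists at the end of the stream via update_message_by_id.
    full_answer = "".join(answer_tokens)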