feat: streaming for standard brain picking (#385)

* feat: streaming for standard brain picking

* fix(bug): private llm

* wip: test

Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com>

* wip: almost good

Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com>

* feat: useFetch

* chore: remove 💀

* chore: fix linting

* fix: forward the request if not streaming

* feat: streaming for standard brain picking

* fix(bug): private llm

* wip: test

Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com>

* wip: almost good

Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com>

* feat: useFetch

* chore: remove 💀

* chore: fix linting

* fix: forward the request if not streaming

* fix: 💀 code

* fix: check_user_limit

* feat: brain_id to new chat stream

* fix: missing imports

* feat: message_id created on backend

Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com>

* chore: remove dead

* remove: cpython

* remove: dead

---------

Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com>
This commit is contained in:
Matt 2023-06-30 09:10:59 +01:00 committed by GitHub
parent 056a68d5ed
commit 6f047f4a39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 553 additions and 270 deletions

View File

@ -1,9 +1,9 @@
# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
from langchain.llms import GPT4All
from langchain.llms.base import LLM
from llm.brainpicking import BrainPicking
from logger import get_logger
from models.settings import LLMSettings
logger = get_logger(__name__)
@ -13,6 +13,9 @@ class PrivateBrainPicking(BrainPicking):
This subclass of BrainPicking is used to specifically work with a private language model.
"""
# Initialize class settings
llm_settings = LLMSettings()
def __init__(
self,
model: str,
@ -28,7 +31,7 @@ class PrivateBrainPicking(BrainPicking):
:param brain_id: The user id to be used for CustomSupabaseVectorStore.
:return: PrivateBrainPicking instance
"""
# Call the parent class's initializer
super().__init__(
model=model,
brain_id=brain_id,
@ -38,20 +41,17 @@ class PrivateBrainPicking(BrainPicking):
user_openai_api_key=user_openai_api_key,
)
def _determine_llm(
self, private_model_args: dict, private: bool = True, model_name: str = None
) -> LLM:
def _create_llm(self, model_name, streaming=False, callbacks=None) -> LLM:
"""
Override the _determine_llm method to enforce the use of a private model.
Override the _create_llm method to enforce the use of a private model.
:param model_name: Language model name to be used.
:param private_model_args: Dictionary containing model_path, n_ctx and n_batch.
:param private: Boolean value to determine if private model is to be used. Defaulted to True.
:return: Language model instance
"""
# Force the use of a private model by setting private to True.
model_path = private_model_args["model_path"]
model_n_ctx = private_model_args["n_ctx"]
model_n_batch = private_model_args["n_batch"]
model_path = self.llm_settings.model_path
model_n_ctx = self.llm_settings.model_n_ctx
model_n_batch = self.llm_settings.model_n_batch
logger.info("Using private model: %s", model_path)

View File

@ -1,4 +1,8 @@
from typing import Any, Dict
import asyncio
import json
from typing import AsyncIterable, Awaitable
from langchain.callbacks import AsyncIteratorCallbackHandler
# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
from langchain.chains import ConversationalRetrievalChain, LLMChain
@ -6,69 +10,53 @@ from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferMemory
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
from logger import get_logger
from models.settings import \
BrainSettings # Importing settings related to the 'brain'
from models.settings import LLMSettings # For type hinting
from models.settings import BrainSettings # Importing settings related to the 'brain'
from pydantic import BaseModel # For data validation and settings management
from repository.chat.get_chat_history import get_chat_history
from vectorstore.supabase import \
CustomSupabaseVectorStore # Custom class for handling vector storage with Supabase
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from supabase import Client # For interacting with Supabase database
from supabase import create_client
from vectorstore.supabase import (
CustomSupabaseVectorStore,
) # Custom class for handling vector storage with Supabase
logger = get_logger(__name__)
class AnswerConversationBufferMemory(ConversationBufferMemory):
"""
This class is a specialized version of ConversationBufferMemory.
It overrides the save_context method to save the response using the 'answer' key in the outputs.
Reference to some issue comment is given in the docstring.
"""
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
# Overriding the save_context method of the parent class
return super(AnswerConversationBufferMemory, self).save_context(
inputs, {"response": outputs["answer"]}
)
def format_chat_history(inputs) -> str:
"""
Function to concatenate chat history into a single string.
:param inputs: List of tuples containing human and AI messages.
:return: concatenated string of chat history
"""
res = []
for human, ai in inputs:
res.append(f"{human}:{ai}\n")
return "\n".join(res)
class BrainPicking(BaseModel):
"""
Main class for the Brain Picking functionality.
It allows to initialize a Chat model, generate questions and retrieve answers using ConversationalRetrievalChain.
"""
# Instantiate settings
settings = BrainSettings()
# Default class attributes
llm_name: str = "gpt-3.5-turbo"
temperature: float = 0.0
settings = BrainSettings()
llm_config = LLMSettings()
embeddings: OpenAIEmbeddings = None
supabase_client: Client = None
vector_store: CustomSupabaseVectorStore = None
llm: LLM = None
question_generator: LLMChain = None
doc_chain: ConversationalRetrievalChain = None
chat_id: str
max_tokens: int = 256
# Storage
supabase_client: Client = None
vector_store: CustomSupabaseVectorStore = None
# Language models
embeddings: OpenAIEmbeddings = None
question_llm: LLM = None
doc_llm: LLM = None
question_generator: LLMChain = None
doc_chain: LLMChain = None
qa: ConversationalRetrievalChain = None
# Streaming
callback: AsyncIteratorCallbackHandler = None
streaming: bool = False
class Config:
# Allowing arbitrary types for class validation
arbitrary_types_allowed = True
@ -81,6 +69,7 @@ class BrainPicking(BaseModel):
chat_id: str,
max_tokens: int,
user_openai_api_key: str,
streaming: bool = False,
) -> "BrainPicking":
"""
Initialize the BrainPicking class by setting embeddings, supabase client, vector store, language model and chains.
@ -113,25 +102,38 @@ class BrainPicking(BaseModel):
brain_id=brain_id,
)
self.llm = self._determine_llm(
private_model_args={
"model_path": self.llm_config.model_path,
"n_ctx": self.llm_config.model_n_ctx,
"n_batch": self.llm_config.model_n_batch,
},
private=self.llm_config.private,
self.question_llm = self._create_llm(
model_name=self.llm_name,
streaming=False,
)
self.question_generator = LLMChain(
llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT
llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT
)
self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
if streaming:
self.callback = AsyncIteratorCallbackHandler()
self.doc_llm = self._create_llm(
model_name=self.llm_name,
streaming=streaming,
callbacks=[self.callback],
)
self.doc_chain = load_qa_chain(
llm=self.doc_llm,
chain_type="stuff",
)
self.streaming = streaming
else:
self.doc_llm = self._create_llm(
model_name=self.llm_name,
streaming=streaming,
)
self.doc_chain = load_qa_chain(llm=self.doc_llm, chain_type="stuff")
self.streaming = streaming
self.chat_id = chat_id
self.max_tokens = max_tokens
def _determine_llm(
self, private_model_args: dict, private: bool = False, model_name: str = None
) -> LLM:
def _create_llm(self, model_name, streaming=False, callbacks=None) -> LLM:
"""
Determine the language model to be used.
:param model_name: Language model name to be used.
@ -139,8 +141,12 @@ class BrainPicking(BaseModel):
:param private: Boolean value to determine if private model is to be used.
:return: Language model instance
"""
return ChatOpenAI(temperature=0, model_name=model_name)
return ChatOpenAI(
temperature=0,
model_name=model_name,
streaming=streaming,
callbacks=callbacks,
)
def _get_qa(
self,
@ -155,11 +161,11 @@ class BrainPicking(BaseModel):
# Initialize and return a ConversationalRetrievalChain
qa = ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(),
max_tokens_limit=self.max_tokens,
question_generator=self.question_generator,
combine_docs_chain=self.doc_chain,
get_chat_history=format_chat_history,
verbose=True,
)
return qa
def generate_answer(self, question: str) -> str:
@ -182,3 +188,70 @@ class BrainPicking(BaseModel):
answer = model_response["answer"]
return answer
async def generate_stream(self, question: str) -> AsyncIterable:
"""
Generate a streaming answer to a given question by interacting with the language model.
:param question: The question
:return: An async iterable which generates the answer.
"""
# Get the QA chain
qa = self._get_qa()
history = get_chat_history(self.chat_id)
callback = self.callback
# # Format the chat history into a list of tuples (human, ai)
transformed_history = [(chat.user_message, chat.assistant) for chat in history]
# Initialize a list to hold the tokens
response_tokens = []
# Wrap an awaitable with a event to signal when it's done or an exception is raised.
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
except Exception as e:
logger.error(f"Caught exception: {e}")
finally:
event.set()
# Use the acall method to perform an async call to the QA chain
task = asyncio.create_task(
wrap_done(
qa.acall(
{
"question": question,
"chat_history": transformed_history,
}
),
callback.done,
)
)
streamed_chat_history = update_chat_history(
chat_id=self.chat_id,
user_message=question,
assistant="",
)
# Use the aiter method of the callback to stream the response with server-sent-events
async for token in callback.aiter():
logger.info("Token: %s", token)
# Add the token to the response_tokens list
response_tokens.append(token)
streamed_chat_history.assistant = token
yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
await task
# Join the tokens to create the assistant's response
assistant = "".join(response_tokens)
update_message_by_id(
message_id=streamed_chat_history.message_id,
user_message=question,
assistant=assistant,
)

View File

@ -10,7 +10,6 @@ from routes.chat_routes import chat_router
from routes.crawl_routes import crawl_router
from routes.explore_routes import explore_router
from routes.misc_routes import misc_router
from routes.stream_routes import stream_router
from routes.upload_routes import upload_router
from routes.user_routes import user_router
@ -35,7 +34,6 @@ app.include_router(misc_router)
app.include_router(upload_router)
app.include_router(user_router)
app.include_router(api_key_router)
app.include_router(stream_router)
@app.exception_handler(HTTPException)

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import asdict, dataclass
@dataclass
@ -29,3 +29,6 @@ class ChatHistory:
self.user_message = chat_dict.get("user_message")
self.assistant = chat_dict.get("assistant")
self.message_time = chat_dict.get("message_time")
def to_dict(self):
return asdict(self)

View File

@ -1,11 +1,10 @@
from dataclasses import dataclass
from typing import Optional
from logger import get_logger
from models.chat import Chat
from typing import Optional
from dataclasses import dataclass
from models.settings import common_dependencies
logger = get_logger(__name__)

View File

@ -1,12 +1,11 @@
from typing import List # For type hinting
from fastapi import HTTPException
from models.chat import ChatHistory
from models.settings import common_dependencies
from typing import List # For type hinting
from fastapi import HTTPException
def update_chat_history(
chat_id: str, user_message: str, assistant_answer: str
) -> ChatHistory:
def update_chat_history(chat_id: str, user_message: str, assistant: str) -> ChatHistory:
commons = common_dependencies()
response: List[ChatHistory] = (
commons["supabase"]
@ -15,7 +14,7 @@ def update_chat_history(
{
"chat_id": str(chat_id),
"user_message": user_message,
"assistant": assistant_answer,
"assistant": assistant,
}
)
.execute()
@ -24,4 +23,4 @@ def update_chat_history(
raise HTTPException(
status_code=500, detail="An exception occurred while updating chat history."
)
return response[0]
return ChatHistory(response[0])

View File

@ -0,0 +1,38 @@
from logger import get_logger
from models.chat import ChatHistory
from models.settings import common_dependencies
logger = get_logger(__name__)
def update_message_by_id(
message_id: str, user_message: str, assistant: str
) -> ChatHistory:
commons = common_dependencies()
if not message_id:
logger.error("No message_id provided")
return
updates = {}
if user_message is not None:
updates["user_message"] = user_message
if assistant is not None:
updates["assistant"] = user_message
updated_message = None
if updates:
updated_message = (
commons["supabase"]
.table("chat_history")
.update(updates)
.match({"message_id": message_id})
.execute()
).data[0]
logger.info(f"Message {message_id} updated")
else:
logger.info(f"No updates to apply for message {message_id}")
return ChatHistory(updated_message)

View File

@ -5,7 +5,8 @@ from typing import List
from uuid import UUID
from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import StreamingResponse
from llm.brainpicking import BrainPicking
from llm.BrainPickingOpenAIFunctions.BrainPickingOpenAIFunctions import (
BrainPickingOpenAIFunctions,
@ -21,6 +22,10 @@ from repository.chat.get_chat_history import get_chat_history
from repository.chat.get_user_chats import get_user_chats
from repository.chat.update_chat import ChatUpdatableProperties, update_chat
from repository.chat.update_chat_history import update_chat_history
from utils.constants import (
openai_function_compatible_models,
streaming_compatible_models,
)
chat_router = APIRouter()
@ -40,6 +45,36 @@ def delete_chat_from_db(commons, chat_id):
commons["supabase"].table("chats").delete().match({"chat_id": chat_id}).execute()
def fetch_user_stats(commons, user, date):
response = (
commons["supabase"]
.from_("users")
.select("*")
.filter("email", "eq", user.email)
.filter("date", "eq", date)
.execute()
)
userItem = next(iter(response.data or []), {"requests_count": 0})
return userItem
def check_user_limit(
user: User,
):
if user.user_openai_api_key is None:
date = time.strftime("%Y%m%d")
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
user.increment_user_request_count(date)
if user.requests_count >= float(max_requests_number):
raise HTTPException(
status_code=429,
detail="You have reached the maximum number of requests for today.",
)
else:
pass
# get all chats
@chat_router.get("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"])
async def get_chats(current_user: User = Depends(get_current_user)):
@ -52,7 +87,6 @@ async def get_chats(current_user: User = Depends(get_current_user)):
This endpoint retrieves all the chats associated with the current authenticated user. It returns a list of chat objects
containing the chat ID and chat name for each chat.
"""
commons = common_dependencies()
chats = get_user_chats(current_user.id)
return {"chats": chats}
@ -82,7 +116,6 @@ async def update_chat_metadata_handler(
"""
Update chat attributes
"""
commons = common_dependencies()
chat = get_chat_by_id(chat_id)
if current_user.id != chat.user_id:
@ -92,24 +125,6 @@ async def update_chat_metadata_handler(
return update_chat(chat_id=chat_id, chat_data=chat_data)
# helper method for update and create chat
def check_user_limit(
user: User,
):
if user.user_openai_api_key is None:
date = time.strftime("%Y%m%d")
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
user.increment_user_request_count(date)
if user.requests_count >= float(max_requests_number):
raise HTTPException(
status_code=429,
detail="You have reached the maximum number of requests for today.",
)
else:
pass
# create new chat
@chat_router.post("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"])
async def create_chat_handler(
@ -139,10 +154,7 @@ async def create_question_handler(
try:
check_user_limit(current_user)
llm_settings = LLMSettings()
openai_function_compatible_models = [
"gpt-3.5-turbo-0613",
"gpt-4-0613",
]
if llm_settings.private:
gpt_answer_generator = PrivateBrainPicking(
model=chat_question.model,
@ -153,6 +165,7 @@ async def create_question_handler(
user_openai_api_key=current_user.user_openai_api_key,
)
answer = gpt_answer_generator.generate_answer(chat_question.question)
elif chat_question.model in openai_function_compatible_models:
# TODO: RBAC with current_user
gpt_answer_generator = BrainPickingOpenAIFunctions(
@ -165,6 +178,7 @@ async def create_question_handler(
user_openai_api_key=current_user.user_openai_api_key,
)
answer = gpt_answer_generator.generate_answer(chat_question.question)
else:
brainPicking = BrainPicking(
chat_id=str(chat_id),
@ -174,18 +188,64 @@ async def create_question_handler(
brain_id=brain_id,
user_openai_api_key=current_user.user_openai_api_key,
)
answer = brainPicking.generate_answer(chat_question.question)
chat_answer = update_chat_history(
chat_id=chat_id,
user_message=chat_question.question,
assistant_answer=answer,
assistant=answer,
)
return chat_answer
except HTTPException as e:
raise e
# stream new question response from chat
@chat_router.post(
"/chat/{chat_id}/question/stream",
dependencies=[Depends(AuthBearer())],
tags=["Chat"],
)
async def create_stream_question_handler(
request: Request,
chat_question: ChatQuestion,
chat_id: UUID,
brain_id: UUID = Query(..., description="The ID of the brain"),
current_user: User = Depends(get_current_user),
) -> StreamingResponse:
if (
os.getenv("PRIVATE") == "True"
or chat_question.model not in streaming_compatible_models
):
# forward the request to the none streaming endpoint create_question_handler function
return await create_question_handler(
request, chat_question, chat_id, current_user
)
try:
user_openai_api_key = request.headers.get("Openai-Api-Key")
check_user_limit(current_user)
brain = BrainPicking(
chat_id=str(chat_id),
model=chat_question.model,
max_tokens=chat_question.max_tokens,
temperature=chat_question.temperature,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
streaming=True,
)
return StreamingResponse(
brain.generate_stream(chat_question.question),
media_type="text/event-stream",
)
except HTTPException as e:
raise e
# get chat history
@chat_router.get(
"/chat/{chat_id}/history", dependencies=[Depends(AuthBearer())], tags=["Chat"]

View File

@ -1,121 +0,0 @@
import asyncio
import os
from typing import AsyncIterable, Awaitable
from uuid import UUID
from auth.auth_bearer import AuthBearer
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
from logger import get_logger
from models.chats import ChatMessage
from models.settings import CommonsDep, common_dependencies
from vectorstore.supabase import CustomSupabaseVectorStore
from supabase import create_client
logger = get_logger(__name__)
stream_router = APIRouter()
openai_api_key = os.getenv("OPENAI_API_KEY")
supabase_url = os.getenv("SUPABASE_URL")
supabase_service_key = os.getenv("SUPABASE_SERVICE_KEY")
async def send_message(
chat_message: ChatMessage, chain, callback
) -> AsyncIterable[str]:
async def wrap_done(fn: Awaitable, event: asyncio.Event):
"""Wrap an awaitable with a event to signal when it's done or an exception is raised."""
try:
resp = await fn
logger.debug("Done: %s", resp)
except Exception as e:
logger.error(f"Caught exception: {e}")
finally:
# Signal the aiter to stop.
event.set()
# Use the agenerate method for models.
# Use the acall method for chains.
task = asyncio.create_task(
wrap_done(
chain.acall(
{
"question": chat_message.question,
"chat_history": chat_message.history,
}
),
callback.done,
)
)
# Use the aiter method of the callback to stream the response with server-sent-events
async for token in callback.aiter():
logger.info("Token: %s", token)
yield f"data: {token}\n\n"
await task
def create_chain(commons: CommonsDep, brain_id: UUID):
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
supabase_client = create_client(supabase_url, supabase_service_key)
vector_store = CustomSupabaseVectorStore(
supabase_client, embeddings, table_name="vectors", brain_id=brain_id
)
generator_llm = ChatOpenAI(
temperature=0,
)
# Callback provides the on_llm_new_token method
callback = AsyncIteratorCallbackHandler()
streaming_llm = ChatOpenAI(
temperature=0,
streaming=True,
callbacks=[callback],
)
question_generator = LLMChain(
llm=generator_llm,
prompt=CONDENSE_QUESTION_PROMPT,
)
doc_chain = load_qa_chain(
llm=streaming_llm,
chain_type="stuff",
)
return (
ConversationalRetrievalChain(
combine_docs_chain=doc_chain,
question_generator=question_generator,
retriever=vector_store.as_retriever(),
verbose=True,
),
callback,
)
@stream_router.post("/stream", dependencies=[Depends(AuthBearer())], tags=["Stream"])
async def stream(
chat_message: ChatMessage,
brain_id: UUID = Query(..., description="The ID of the brain"),
) -> StreamingResponse:
commons = common_dependencies()
qa_chain, callback = create_chain(commons, brain_id)
return StreamingResponse(
send_message(chat_message, qa_chain, callback),
media_type="text/event-stream",
)

View File

@ -1,8 +0,0 @@
import os
from typing import Annotated
from fastapi import Depends
from logger import get_logger
from models.settings import common_dependencies
logger = get_logger(__name__)

View File

@ -0,0 +1,8 @@
openai_function_compatible_models = [
"gpt-3.5-turbo-0613",
"gpt-4-0613",
]
streaming_compatible_models = ["gpt-3.5-turbo"]
private_models = ["gpt4all-j-1.3"]

View File

@ -8,6 +8,8 @@ type ChatContextProps = {
history: ChatHistory[];
setHistory: (history: ChatHistory[]) => void;
addToHistory: (message: ChatHistory) => void;
updateHistory: (chat: ChatHistory) => void;
updateStreamingHistory: (streamedChat: ChatHistory) => void;
};
export const ChatContext = createContext<ChatContextProps | undefined>(
@ -20,16 +22,54 @@ export const ChatProvider = ({
children: JSX.Element | JSX.Element[];
}): JSX.Element => {
const [history, setHistory] = useState<ChatHistory[]>([]);
const addToHistory = (message: ChatHistory) => {
setHistory((prevHistory) => [...prevHistory, message]);
};
const updateStreamingHistory = (streamedChat: ChatHistory): void => {
setHistory((prevHistory: ChatHistory[]) => {
console.log("new chat", streamedChat);
const updatedHistory = prevHistory.find(
(item) => item.message_id === streamedChat.message_id
)
? prevHistory.map((item: ChatHistory) =>
item.message_id === streamedChat.message_id
? { ...item, assistant: item.assistant + streamedChat.assistant }
: item
)
: [...prevHistory, streamedChat];
console.log("updated history", updatedHistory);
return updatedHistory;
});
};
const updateHistory = (chat: ChatHistory): void => {
setHistory((prevHistory: ChatHistory[]) => {
const updatedHistory = prevHistory.find(
(item) => item.message_id === chat.message_id
)
? prevHistory.map((item: ChatHistory) =>
item.message_id === chat.message_id
? { ...item, assistant: chat.assistant }
: item
)
: [...prevHistory, chat];
return updatedHistory;
});
};
return (
<ChatContext.Provider
value={{
history,
setHistory,
addToHistory,
updateHistory,
updateStreamingHistory,
}}
>
{children}

View File

@ -1,3 +1,4 @@
/* eslint-disable max-lines */
import { AxiosError } from "axios";
import { useParams } from "next/navigation";
import { useEffect, useState } from "react";
@ -21,30 +22,34 @@ export const useChat = () => {
const {
config: { maxTokens, model, temperature },
} = useBrainConfig();
const { history, setHistory, addToHistory } = useChatContext();
const { history, setHistory } = useChatContext();
const { publish } = useToast();
const {
createChat,
getChatHistory,
addQuestion: addQuestionToChat,
addStreamQuestion,
addQuestion: addQuestionToModel,
} = useChatService();
useEffect(() => {
const fetchHistory = async () => {
const chatHistory = await getChatHistory(chatId);
setHistory(chatHistory);
const currentChatId = chatId;
const chatHistory = await getChatHistory(currentChatId);
if (chatId === currentChatId && chatHistory.length > 0) {
setHistory(chatHistory);
}
};
void fetchHistory();
}, [chatId]);
}, [chatId, getChatHistory, setHistory]);
const generateNewChatIdFromName = async (
chatName: string
): Promise<string> => {
const rep = await createChat({ name: chatName });
setChatId(rep.data.chat_id);
const chat = await createChat({ name: chatName });
return rep.data.chat_id;
return chat.chat_id;
};
const addQuestion = async (question: string, callback?: () => void) => {
@ -64,8 +69,15 @@ export const useChat = () => {
(await generateNewChatIdFromName(
question.split(" ").slice(0, 3).join(" ")
));
const answer = await addQuestionToChat(currentChatId, chatQuestion);
addToHistory(answer);
setChatId(currentChatId);
if (chatQuestion.model === "gpt-3.5-turbo") {
await addStreamQuestion(currentChatId, chatQuestion);
} else {
await addQuestionToModel(currentChatId, chatQuestion);
}
callback?.();
} catch (error) {
console.error({ error });
@ -88,5 +100,9 @@ export const useChat = () => {
}
};
return { history, addQuestion, generatingAnswer };
return {
history,
addQuestion,
generatingAnswer,
};
};

View File

@ -1,45 +1,137 @@
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
import { useAxios } from "@/lib/hooks";
/* eslint-disable max-lines */
import { useCallback } from "react";
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
import { useAxios, useFetch } from "@/lib/hooks";
import { useChatContext } from "../context/ChatContext";
import { ChatEntity, ChatHistory, ChatQuestion } from "../types";
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useChatService = () => {
interface UseChatService {
createChat: (name: { name: string }) => Promise<ChatEntity>;
getChatHistory: (chatId: string | undefined) => Promise<ChatHistory[]>;
addQuestion: (chatId: string, chatQuestion: ChatQuestion) => Promise<void>;
addStreamQuestion: (
chatId: string,
chatQuestion: ChatQuestion
) => Promise<void>;
}
export const useChatService = (): UseChatService => {
const { axiosInstance } = useAxios();
const { fetchInstance } = useFetch();
const { updateHistory, updateStreamingHistory } = useChatContext();
const { currentBrain } = useBrainContext();
const createChat = async ({ name }: { name: string }) => {
return axiosInstance.post<ChatEntity>(`/chat`, { name });
const createChat = async ({
name,
}: {
name: string;
}): Promise<ChatEntity> => {
const response = (await axiosInstance.post<ChatEntity>(`/chat`, { name }))
.data;
return response;
};
const getChatHistory = async (chatId: string | undefined) => {
if (chatId === undefined) {
return [];
}
const rep = (
await axiosInstance.get<ChatHistory[]>(`/chat/${chatId}/history`)
).data;
const getChatHistory = useCallback(
async (chatId: string | undefined): Promise<ChatHistory[]> => {
if (chatId === undefined) {
return [];
}
const response = (
await axiosInstance.get<ChatHistory[]>(`/chat/${chatId}/history`)
).data;
return response;
},
[axiosInstance]
);
return rep;
};
const addQuestion = async (
chatId: string,
chatQuestion: ChatQuestion
): Promise<ChatHistory> => {
): Promise<void> => {
if (currentBrain?.id === undefined) {
throw new Error("No current brain");
}
return (
await axiosInstance.post<ChatHistory>(
`/chat/${chatId}/question?brain_id=${currentBrain.id}`,
chatQuestion
)
).data;
const response = await axiosInstance.post<ChatHistory>(
`/chat/${chatId}/question?brain_id=${currentBrain.id}`,
chatQuestion
);
updateHistory(response.data);
};
const handleStream = async (
reader: ReadableStreamDefaultReader<Uint8Array>
): Promise<void> => {
const decoder = new TextDecoder("utf-8");
const handleStreamRecursively = async () => {
const { done, value } = await reader.read();
if (done) {
return;
}
const dataStrings = decoder
.decode(value)
.trim()
.split("data: ")
.filter(Boolean);
dataStrings.forEach((data) => {
try {
const parsedData = JSON.parse(data) as ChatHistory;
updateStreamingHistory(parsedData);
} catch (error) {
console.error("Error parsing data:", error);
}
});
await handleStreamRecursively();
};
await handleStreamRecursively();
};
const addStreamQuestion = async (
chatId: string,
chatQuestion: ChatQuestion
): Promise<void> => {
if (currentBrain?.id === undefined) {
throw new Error("No current brain");
}
const headers = {
"Content-Type": "application/json",
Accept: "text/event-stream",
};
const body = JSON.stringify(chatQuestion);
try {
const response = await fetchInstance.post(
`/chat/${chatId}/question/stream?brain_id=${currentBrain.id}`,
body,
headers
);
if (response.body === null) {
throw new Error("Response body is null");
}
console.log("Received response. Starting to handle stream...");
await handleStream(response.body.getReader());
} catch (error) {
console.error("Error calling the API:", error);
}
};
return {
createChat,
getChatHistory,
addQuestion,
addStreamQuestion,
};
};

View File

@ -0,0 +1,13 @@
export const generateUUID = (): string => {
const array = new Uint32Array(4);
window.crypto.getRandomValues(array);
let idx = -1;
return "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx".replace(/[xy]/g, (c) => {
idx++;
const r = (array[idx >> 3] >> ((idx % 8) * 4)) & 15;
const v = c === "x" ? r : (r & 0x3) | 0x8;
return v.toString(16);
});
};

View File

@ -1,2 +1,3 @@
export * from "./useAxios";
export * from "./useFetch";
export * from "./useToast";

View File

@ -0,0 +1,72 @@
import { useEffect, useState } from "react";
import { useBrainConfig } from "../context/BrainConfigProvider/hooks/useBrainConfig";
import { useSupabase } from "../context/SupabaseProvider";
interface FetchInstance {
get: (url: string, headers?: HeadersInit) => Promise<Response>;
post: (
url: string,
body: BodyInit | null | undefined,
headers?: HeadersInit
) => Promise<Response>;
put: (
url: string,
body: BodyInit | null | undefined,
headers?: HeadersInit
) => Promise<Response>;
delete: (url: string, headers?: HeadersInit) => Promise<Response>;
}
const fetchInstance: FetchInstance = {
get: async (url, headers) => fetch(url, { method: "GET", headers }),
post: async (url, body, headers) =>
fetch(url, { method: "POST", body, headers }),
put: async (url, body, headers) =>
fetch(url, { method: "PUT", body, headers }),
delete: async (url, headers) => fetch(url, { method: "DELETE", headers }),
};
export const useFetch = (): { fetchInstance: FetchInstance } => {
const { session } = useSupabase();
const {
config: { backendUrl: configBackendUrl, openAiKey },
} = useBrainConfig();
const [instance, setInstance] = useState(fetchInstance);
const baseURL = `${process.env.NEXT_PUBLIC_BACKEND_URL ?? ""}`;
const backendUrl = configBackendUrl ?? baseURL;
useEffect(() => {
setInstance({
...fetchInstance,
get: async (url, headers) =>
fetchInstance.get(`${backendUrl}${url}`, {
Authorization: `Bearer ${session?.access_token ?? ""}`,
"Openai-Api-Key": openAiKey ?? "",
...headers,
}),
post: async (url, body, headers) =>
fetchInstance.post(`${backendUrl}${url}`, body, {
Authorization: `Bearer ${session?.access_token ?? ""}`,
"Openai-Api-Key": openAiKey ?? "",
...headers,
}),
put: async (url, body, headers) =>
fetchInstance.put(`${backendUrl}${url}`, body, {
Authorization: `Bearer ${session?.access_token ?? ""}`,
"Openai-Api-Key": openAiKey ?? "",
...headers,
}),
delete: async (url, headers) =>
fetchInstance.delete(`${backendUrl}${url}`, {
Authorization: `Bearer ${session?.access_token ?? ""}`,
"Openai-Api-Key": openAiKey ?? "",
...headers,
}),
});
}, [session, backendUrl, openAiKey]);
return { fetchInstance: instance };
};