mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-14 17:03:29 +03:00
feat: streaming for standard brain picking (#385)
* feat: streaming for standard brain picking * fix(bug): private llm * wip: test Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com> * wip: almost good Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com> * feat: useFetch * chore: remove 💀 * chore: fix linting * fix: forward the request if not streaming * feat: streaming for standard brain picking * fix(bug): private llm * wip: test Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com> * wip: almost good Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com> * feat: useFetch * chore: remove 💀 * chore: fix linting * fix: forward the request if not streaming * fix: 💀 code * fix: check_user_limit * feat: brain_id to new chat stream * fix: missing imports * feat: message_id created on backend Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com> * chore: remove dead * remove: cpython * remove: dead --------- Co-authored-by: Mamadou DICKO <mamadoudicko@users.noreply.github.com>
This commit is contained in:
parent
056a68d5ed
commit
6f047f4a39
@ -1,9 +1,9 @@
|
||||
|
||||
# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
|
||||
from langchain.llms import GPT4All
|
||||
from langchain.llms.base import LLM
|
||||
from llm.brainpicking import BrainPicking
|
||||
from logger import get_logger
|
||||
from models.settings import LLMSettings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -13,6 +13,9 @@ class PrivateBrainPicking(BrainPicking):
|
||||
This subclass of BrainPicking is used to specifically work with a private language model.
|
||||
"""
|
||||
|
||||
# Initialize class settings
|
||||
llm_settings = LLMSettings()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
@ -28,7 +31,7 @@ class PrivateBrainPicking(BrainPicking):
|
||||
:param brain_id: The user id to be used for CustomSupabaseVectorStore.
|
||||
:return: PrivateBrainPicking instance
|
||||
"""
|
||||
# Call the parent class's initializer
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
brain_id=brain_id,
|
||||
@ -38,20 +41,17 @@ class PrivateBrainPicking(BrainPicking):
|
||||
user_openai_api_key=user_openai_api_key,
|
||||
)
|
||||
|
||||
def _determine_llm(
|
||||
self, private_model_args: dict, private: bool = True, model_name: str = None
|
||||
) -> LLM:
|
||||
def _create_llm(self, model_name, streaming=False, callbacks=None) -> LLM:
|
||||
"""
|
||||
Override the _determine_llm method to enforce the use of a private model.
|
||||
Override the _create_llm method to enforce the use of a private model.
|
||||
:param model_name: Language model name to be used.
|
||||
:param private_model_args: Dictionary containing model_path, n_ctx and n_batch.
|
||||
:param private: Boolean value to determine if private model is to be used. Defaulted to True.
|
||||
:return: Language model instance
|
||||
"""
|
||||
# Force the use of a private model by setting private to True.
|
||||
model_path = private_model_args["model_path"]
|
||||
model_n_ctx = private_model_args["n_ctx"]
|
||||
model_n_batch = private_model_args["n_batch"]
|
||||
model_path = self.llm_settings.model_path
|
||||
model_n_ctx = self.llm_settings.model_n_ctx
|
||||
model_n_batch = self.llm_settings.model_n_batch
|
||||
|
||||
logger.info("Using private model: %s", model_path)
|
||||
|
||||
|
@ -1,4 +1,8 @@
|
||||
from typing import Any, Dict
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncIterable, Awaitable
|
||||
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
|
||||
# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
|
||||
from langchain.chains import ConversationalRetrievalChain, LLMChain
|
||||
@ -6,69 +10,53 @@ from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
|
||||
from logger import get_logger
|
||||
from models.settings import \
|
||||
BrainSettings # Importing settings related to the 'brain'
|
||||
from models.settings import LLMSettings # For type hinting
|
||||
from models.settings import BrainSettings # Importing settings related to the 'brain'
|
||||
from pydantic import BaseModel # For data validation and settings management
|
||||
from repository.chat.get_chat_history import get_chat_history
|
||||
from vectorstore.supabase import \
|
||||
CustomSupabaseVectorStore # Custom class for handling vector storage with Supabase
|
||||
|
||||
from repository.chat.update_chat_history import update_chat_history
|
||||
from repository.chat.update_message_by_id import update_message_by_id
|
||||
from supabase import Client # For interacting with Supabase database
|
||||
from supabase import create_client
|
||||
from vectorstore.supabase import (
|
||||
CustomSupabaseVectorStore,
|
||||
) # Custom class for handling vector storage with Supabase
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnswerConversationBufferMemory(ConversationBufferMemory):
|
||||
"""
|
||||
This class is a specialized version of ConversationBufferMemory.
|
||||
It overrides the save_context method to save the response using the 'answer' key in the outputs.
|
||||
Reference to some issue comment is given in the docstring.
|
||||
"""
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
# Overriding the save_context method of the parent class
|
||||
return super(AnswerConversationBufferMemory, self).save_context(
|
||||
inputs, {"response": outputs["answer"]}
|
||||
)
|
||||
|
||||
|
||||
def format_chat_history(inputs) -> str:
|
||||
"""
|
||||
Function to concatenate chat history into a single string.
|
||||
:param inputs: List of tuples containing human and AI messages.
|
||||
:return: concatenated string of chat history
|
||||
"""
|
||||
res = []
|
||||
for human, ai in inputs:
|
||||
res.append(f"{human}:{ai}\n")
|
||||
return "\n".join(res)
|
||||
|
||||
|
||||
class BrainPicking(BaseModel):
|
||||
"""
|
||||
Main class for the Brain Picking functionality.
|
||||
It allows to initialize a Chat model, generate questions and retrieve answers using ConversationalRetrievalChain.
|
||||
"""
|
||||
|
||||
# Instantiate settings
|
||||
settings = BrainSettings()
|
||||
|
||||
# Default class attributes
|
||||
llm_name: str = "gpt-3.5-turbo"
|
||||
temperature: float = 0.0
|
||||
settings = BrainSettings()
|
||||
llm_config = LLMSettings()
|
||||
embeddings: OpenAIEmbeddings = None
|
||||
supabase_client: Client = None
|
||||
vector_store: CustomSupabaseVectorStore = None
|
||||
llm: LLM = None
|
||||
question_generator: LLMChain = None
|
||||
doc_chain: ConversationalRetrievalChain = None
|
||||
chat_id: str
|
||||
max_tokens: int = 256
|
||||
|
||||
# Storage
|
||||
supabase_client: Client = None
|
||||
vector_store: CustomSupabaseVectorStore = None
|
||||
|
||||
# Language models
|
||||
embeddings: OpenAIEmbeddings = None
|
||||
question_llm: LLM = None
|
||||
doc_llm: LLM = None
|
||||
question_generator: LLMChain = None
|
||||
doc_chain: LLMChain = None
|
||||
qa: ConversationalRetrievalChain = None
|
||||
|
||||
# Streaming
|
||||
callback: AsyncIteratorCallbackHandler = None
|
||||
streaming: bool = False
|
||||
|
||||
class Config:
|
||||
# Allowing arbitrary types for class validation
|
||||
arbitrary_types_allowed = True
|
||||
@ -81,6 +69,7 @@ class BrainPicking(BaseModel):
|
||||
chat_id: str,
|
||||
max_tokens: int,
|
||||
user_openai_api_key: str,
|
||||
streaming: bool = False,
|
||||
) -> "BrainPicking":
|
||||
"""
|
||||
Initialize the BrainPicking class by setting embeddings, supabase client, vector store, language model and chains.
|
||||
@ -113,25 +102,38 @@ class BrainPicking(BaseModel):
|
||||
brain_id=brain_id,
|
||||
)
|
||||
|
||||
self.llm = self._determine_llm(
|
||||
private_model_args={
|
||||
"model_path": self.llm_config.model_path,
|
||||
"n_ctx": self.llm_config.model_n_ctx,
|
||||
"n_batch": self.llm_config.model_n_batch,
|
||||
},
|
||||
private=self.llm_config.private,
|
||||
self.question_llm = self._create_llm(
|
||||
model_name=self.llm_name,
|
||||
streaming=False,
|
||||
)
|
||||
self.question_generator = LLMChain(
|
||||
llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT
|
||||
llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT
|
||||
)
|
||||
self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
|
||||
|
||||
if streaming:
|
||||
self.callback = AsyncIteratorCallbackHandler()
|
||||
self.doc_llm = self._create_llm(
|
||||
model_name=self.llm_name,
|
||||
streaming=streaming,
|
||||
callbacks=[self.callback],
|
||||
)
|
||||
self.doc_chain = load_qa_chain(
|
||||
llm=self.doc_llm,
|
||||
chain_type="stuff",
|
||||
)
|
||||
self.streaming = streaming
|
||||
else:
|
||||
self.doc_llm = self._create_llm(
|
||||
model_name=self.llm_name,
|
||||
streaming=streaming,
|
||||
)
|
||||
self.doc_chain = load_qa_chain(llm=self.doc_llm, chain_type="stuff")
|
||||
self.streaming = streaming
|
||||
|
||||
self.chat_id = chat_id
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def _determine_llm(
|
||||
self, private_model_args: dict, private: bool = False, model_name: str = None
|
||||
) -> LLM:
|
||||
def _create_llm(self, model_name, streaming=False, callbacks=None) -> LLM:
|
||||
"""
|
||||
Determine the language model to be used.
|
||||
:param model_name: Language model name to be used.
|
||||
@ -139,8 +141,12 @@ class BrainPicking(BaseModel):
|
||||
:param private: Boolean value to determine if private model is to be used.
|
||||
:return: Language model instance
|
||||
"""
|
||||
|
||||
return ChatOpenAI(temperature=0, model_name=model_name)
|
||||
return ChatOpenAI(
|
||||
temperature=0,
|
||||
model_name=model_name,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
def _get_qa(
|
||||
self,
|
||||
@ -155,11 +161,11 @@ class BrainPicking(BaseModel):
|
||||
# Initialize and return a ConversationalRetrievalChain
|
||||
qa = ConversationalRetrievalChain(
|
||||
retriever=self.vector_store.as_retriever(),
|
||||
max_tokens_limit=self.max_tokens,
|
||||
question_generator=self.question_generator,
|
||||
combine_docs_chain=self.doc_chain,
|
||||
get_chat_history=format_chat_history,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
return qa
|
||||
|
||||
def generate_answer(self, question: str) -> str:
|
||||
@ -182,3 +188,70 @@ class BrainPicking(BaseModel):
|
||||
answer = model_response["answer"]
|
||||
|
||||
return answer
|
||||
|
||||
async def generate_stream(self, question: str) -> AsyncIterable:
|
||||
"""
|
||||
Generate a streaming answer to a given question by interacting with the language model.
|
||||
:param question: The question
|
||||
:return: An async iterable which generates the answer.
|
||||
"""
|
||||
|
||||
# Get the QA chain
|
||||
qa = self._get_qa()
|
||||
history = get_chat_history(self.chat_id)
|
||||
callback = self.callback
|
||||
|
||||
# # Format the chat history into a list of tuples (human, ai)
|
||||
transformed_history = [(chat.user_message, chat.assistant) for chat in history]
|
||||
|
||||
# Initialize a list to hold the tokens
|
||||
response_tokens = []
|
||||
|
||||
# Wrap an awaitable with a event to signal when it's done or an exception is raised.
|
||||
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
try:
|
||||
await fn
|
||||
except Exception as e:
|
||||
logger.error(f"Caught exception: {e}")
|
||||
finally:
|
||||
event.set()
|
||||
|
||||
# Use the acall method to perform an async call to the QA chain
|
||||
task = asyncio.create_task(
|
||||
wrap_done(
|
||||
qa.acall(
|
||||
{
|
||||
"question": question,
|
||||
"chat_history": transformed_history,
|
||||
}
|
||||
),
|
||||
callback.done,
|
||||
)
|
||||
)
|
||||
|
||||
streamed_chat_history = update_chat_history(
|
||||
chat_id=self.chat_id,
|
||||
user_message=question,
|
||||
assistant="",
|
||||
)
|
||||
|
||||
# Use the aiter method of the callback to stream the response with server-sent-events
|
||||
async for token in callback.aiter():
|
||||
logger.info("Token: %s", token)
|
||||
|
||||
# Add the token to the response_tokens list
|
||||
response_tokens.append(token)
|
||||
streamed_chat_history.assistant = token
|
||||
|
||||
yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
|
||||
|
||||
await task
|
||||
|
||||
# Join the tokens to create the assistant's response
|
||||
assistant = "".join(response_tokens)
|
||||
|
||||
update_message_by_id(
|
||||
message_id=streamed_chat_history.message_id,
|
||||
user_message=question,
|
||||
assistant=assistant,
|
||||
)
|
||||
|
@ -10,7 +10,6 @@ from routes.chat_routes import chat_router
|
||||
from routes.crawl_routes import crawl_router
|
||||
from routes.explore_routes import explore_router
|
||||
from routes.misc_routes import misc_router
|
||||
from routes.stream_routes import stream_router
|
||||
from routes.upload_routes import upload_router
|
||||
from routes.user_routes import user_router
|
||||
|
||||
@ -35,7 +34,6 @@ app.include_router(misc_router)
|
||||
app.include_router(upload_router)
|
||||
app.include_router(user_router)
|
||||
app.include_router(api_key_router)
|
||||
app.include_router(stream_router)
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -29,3 +29,6 @@ class ChatHistory:
|
||||
self.user_message = chat_dict.get("user_message")
|
||||
self.assistant = chat_dict.get("assistant")
|
||||
self.message_time = chat_dict.get("message_time")
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
@ -1,11 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from logger import get_logger
|
||||
from models.chat import Chat
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from models.settings import common_dependencies
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
|
@ -1,12 +1,11 @@
|
||||
from typing import List # For type hinting
|
||||
|
||||
from fastapi import HTTPException
|
||||
from models.chat import ChatHistory
|
||||
from models.settings import common_dependencies
|
||||
from typing import List # For type hinting
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def update_chat_history(
|
||||
chat_id: str, user_message: str, assistant_answer: str
|
||||
) -> ChatHistory:
|
||||
def update_chat_history(chat_id: str, user_message: str, assistant: str) -> ChatHistory:
|
||||
commons = common_dependencies()
|
||||
response: List[ChatHistory] = (
|
||||
commons["supabase"]
|
||||
@ -15,7 +14,7 @@ def update_chat_history(
|
||||
{
|
||||
"chat_id": str(chat_id),
|
||||
"user_message": user_message,
|
||||
"assistant": assistant_answer,
|
||||
"assistant": assistant,
|
||||
}
|
||||
)
|
||||
.execute()
|
||||
@ -24,4 +23,4 @@ def update_chat_history(
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An exception occurred while updating chat history."
|
||||
)
|
||||
return response[0]
|
||||
return ChatHistory(response[0])
|
||||
|
38
backend/repository/chat/update_message_by_id.py
Normal file
38
backend/repository/chat/update_message_by_id.py
Normal file
@ -0,0 +1,38 @@
|
||||
from logger import get_logger
|
||||
from models.chat import ChatHistory
|
||||
from models.settings import common_dependencies
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def update_message_by_id(
|
||||
message_id: str, user_message: str, assistant: str
|
||||
) -> ChatHistory:
|
||||
commons = common_dependencies()
|
||||
|
||||
if not message_id:
|
||||
logger.error("No message_id provided")
|
||||
return
|
||||
|
||||
updates = {}
|
||||
|
||||
if user_message is not None:
|
||||
updates["user_message"] = user_message
|
||||
|
||||
if assistant is not None:
|
||||
updates["assistant"] = user_message
|
||||
|
||||
updated_message = None
|
||||
|
||||
if updates:
|
||||
updated_message = (
|
||||
commons["supabase"]
|
||||
.table("chat_history")
|
||||
.update(updates)
|
||||
.match({"message_id": message_id})
|
||||
.execute()
|
||||
).data[0]
|
||||
logger.info(f"Message {message_id} updated")
|
||||
else:
|
||||
logger.info(f"No updates to apply for message {message_id}")
|
||||
return ChatHistory(updated_message)
|
@ -5,7 +5,8 @@ from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from auth.auth_bearer import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llm.brainpicking import BrainPicking
|
||||
from llm.BrainPickingOpenAIFunctions.BrainPickingOpenAIFunctions import (
|
||||
BrainPickingOpenAIFunctions,
|
||||
@ -21,6 +22,10 @@ from repository.chat.get_chat_history import get_chat_history
|
||||
from repository.chat.get_user_chats import get_user_chats
|
||||
from repository.chat.update_chat import ChatUpdatableProperties, update_chat
|
||||
from repository.chat.update_chat_history import update_chat_history
|
||||
from utils.constants import (
|
||||
openai_function_compatible_models,
|
||||
streaming_compatible_models,
|
||||
)
|
||||
|
||||
chat_router = APIRouter()
|
||||
|
||||
@ -40,6 +45,36 @@ def delete_chat_from_db(commons, chat_id):
|
||||
commons["supabase"].table("chats").delete().match({"chat_id": chat_id}).execute()
|
||||
|
||||
|
||||
def fetch_user_stats(commons, user, date):
|
||||
response = (
|
||||
commons["supabase"]
|
||||
.from_("users")
|
||||
.select("*")
|
||||
.filter("email", "eq", user.email)
|
||||
.filter("date", "eq", date)
|
||||
.execute()
|
||||
)
|
||||
userItem = next(iter(response.data or []), {"requests_count": 0})
|
||||
return userItem
|
||||
|
||||
|
||||
def check_user_limit(
|
||||
user: User,
|
||||
):
|
||||
if user.user_openai_api_key is None:
|
||||
date = time.strftime("%Y%m%d")
|
||||
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
|
||||
|
||||
user.increment_user_request_count(date)
|
||||
if user.requests_count >= float(max_requests_number):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="You have reached the maximum number of requests for today.",
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
# get all chats
|
||||
@chat_router.get("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def get_chats(current_user: User = Depends(get_current_user)):
|
||||
@ -52,7 +87,6 @@ async def get_chats(current_user: User = Depends(get_current_user)):
|
||||
This endpoint retrieves all the chats associated with the current authenticated user. It returns a list of chat objects
|
||||
containing the chat ID and chat name for each chat.
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
chats = get_user_chats(current_user.id)
|
||||
return {"chats": chats}
|
||||
|
||||
@ -82,7 +116,6 @@ async def update_chat_metadata_handler(
|
||||
"""
|
||||
Update chat attributes
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
|
||||
chat = get_chat_by_id(chat_id)
|
||||
if current_user.id != chat.user_id:
|
||||
@ -92,24 +125,6 @@ async def update_chat_metadata_handler(
|
||||
return update_chat(chat_id=chat_id, chat_data=chat_data)
|
||||
|
||||
|
||||
# helper method for update and create chat
|
||||
def check_user_limit(
|
||||
user: User,
|
||||
):
|
||||
if user.user_openai_api_key is None:
|
||||
date = time.strftime("%Y%m%d")
|
||||
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
|
||||
|
||||
user.increment_user_request_count(date)
|
||||
if user.requests_count >= float(max_requests_number):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="You have reached the maximum number of requests for today.",
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
# create new chat
|
||||
@chat_router.post("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def create_chat_handler(
|
||||
@ -139,10 +154,7 @@ async def create_question_handler(
|
||||
try:
|
||||
check_user_limit(current_user)
|
||||
llm_settings = LLMSettings()
|
||||
openai_function_compatible_models = [
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-4-0613",
|
||||
]
|
||||
|
||||
if llm_settings.private:
|
||||
gpt_answer_generator = PrivateBrainPicking(
|
||||
model=chat_question.model,
|
||||
@ -153,6 +165,7 @@ async def create_question_handler(
|
||||
user_openai_api_key=current_user.user_openai_api_key,
|
||||
)
|
||||
answer = gpt_answer_generator.generate_answer(chat_question.question)
|
||||
|
||||
elif chat_question.model in openai_function_compatible_models:
|
||||
# TODO: RBAC with current_user
|
||||
gpt_answer_generator = BrainPickingOpenAIFunctions(
|
||||
@ -165,6 +178,7 @@ async def create_question_handler(
|
||||
user_openai_api_key=current_user.user_openai_api_key,
|
||||
)
|
||||
answer = gpt_answer_generator.generate_answer(chat_question.question)
|
||||
|
||||
else:
|
||||
brainPicking = BrainPicking(
|
||||
chat_id=str(chat_id),
|
||||
@ -174,18 +188,64 @@ async def create_question_handler(
|
||||
brain_id=brain_id,
|
||||
user_openai_api_key=current_user.user_openai_api_key,
|
||||
)
|
||||
|
||||
answer = brainPicking.generate_answer(chat_question.question)
|
||||
|
||||
chat_answer = update_chat_history(
|
||||
chat_id=chat_id,
|
||||
user_message=chat_question.question,
|
||||
assistant_answer=answer,
|
||||
assistant=answer,
|
||||
)
|
||||
return chat_answer
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
|
||||
|
||||
# stream new question response from chat
|
||||
@chat_router.post(
|
||||
"/chat/{chat_id}/question/stream",
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["Chat"],
|
||||
)
|
||||
async def create_stream_question_handler(
|
||||
request: Request,
|
||||
chat_question: ChatQuestion,
|
||||
chat_id: UUID,
|
||||
brain_id: UUID = Query(..., description="The ID of the brain"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
if (
|
||||
os.getenv("PRIVATE") == "True"
|
||||
or chat_question.model not in streaming_compatible_models
|
||||
):
|
||||
# forward the request to the none streaming endpoint create_question_handler function
|
||||
return await create_question_handler(
|
||||
request, chat_question, chat_id, current_user
|
||||
)
|
||||
|
||||
try:
|
||||
user_openai_api_key = request.headers.get("Openai-Api-Key")
|
||||
check_user_limit(current_user)
|
||||
|
||||
brain = BrainPicking(
|
||||
chat_id=str(chat_id),
|
||||
model=chat_question.model,
|
||||
max_tokens=chat_question.max_tokens,
|
||||
temperature=chat_question.temperature,
|
||||
brain_id=brain_id,
|
||||
user_openai_api_key=user_openai_api_key,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
brain.generate_stream(chat_question.question),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
|
||||
|
||||
# get chat history
|
||||
@chat_router.get(
|
||||
"/chat/{chat_id}/history", dependencies=[Depends(AuthBearer())], tags=["Chat"]
|
||||
|
@ -1,121 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncIterable, Awaitable
|
||||
from uuid import UUID
|
||||
|
||||
from auth.auth_bearer import AuthBearer
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import ConversationalRetrievalChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
|
||||
from logger import get_logger
|
||||
from models.chats import ChatMessage
|
||||
from models.settings import CommonsDep, common_dependencies
|
||||
from vectorstore.supabase import CustomSupabaseVectorStore
|
||||
|
||||
from supabase import create_client
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
stream_router = APIRouter()
|
||||
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
supabase_url = os.getenv("SUPABASE_URL")
|
||||
supabase_service_key = os.getenv("SUPABASE_SERVICE_KEY")
|
||||
|
||||
|
||||
async def send_message(
|
||||
chat_message: ChatMessage, chain, callback
|
||||
) -> AsyncIterable[str]:
|
||||
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
"""Wrap an awaitable with a event to signal when it's done or an exception is raised."""
|
||||
try:
|
||||
resp = await fn
|
||||
logger.debug("Done: %s", resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Caught exception: {e}")
|
||||
finally:
|
||||
# Signal the aiter to stop.
|
||||
event.set()
|
||||
|
||||
# Use the agenerate method for models.
|
||||
# Use the acall method for chains.
|
||||
task = asyncio.create_task(
|
||||
wrap_done(
|
||||
chain.acall(
|
||||
{
|
||||
"question": chat_message.question,
|
||||
"chat_history": chat_message.history,
|
||||
}
|
||||
),
|
||||
callback.done,
|
||||
)
|
||||
)
|
||||
|
||||
# Use the aiter method of the callback to stream the response with server-sent-events
|
||||
async for token in callback.aiter():
|
||||
logger.info("Token: %s", token)
|
||||
yield f"data: {token}\n\n"
|
||||
|
||||
await task
|
||||
|
||||
|
||||
def create_chain(commons: CommonsDep, brain_id: UUID):
|
||||
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
|
||||
|
||||
supabase_client = create_client(supabase_url, supabase_service_key)
|
||||
|
||||
vector_store = CustomSupabaseVectorStore(
|
||||
supabase_client, embeddings, table_name="vectors", brain_id=brain_id
|
||||
)
|
||||
|
||||
generator_llm = ChatOpenAI(
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
# Callback provides the on_llm_new_token method
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
|
||||
streaming_llm = ChatOpenAI(
|
||||
temperature=0,
|
||||
streaming=True,
|
||||
callbacks=[callback],
|
||||
)
|
||||
question_generator = LLMChain(
|
||||
llm=generator_llm,
|
||||
prompt=CONDENSE_QUESTION_PROMPT,
|
||||
)
|
||||
doc_chain = load_qa_chain(
|
||||
llm=streaming_llm,
|
||||
chain_type="stuff",
|
||||
)
|
||||
|
||||
return (
|
||||
ConversationalRetrievalChain(
|
||||
combine_docs_chain=doc_chain,
|
||||
question_generator=question_generator,
|
||||
retriever=vector_store.as_retriever(),
|
||||
verbose=True,
|
||||
),
|
||||
callback,
|
||||
)
|
||||
|
||||
|
||||
@stream_router.post("/stream", dependencies=[Depends(AuthBearer())], tags=["Stream"])
|
||||
async def stream(
|
||||
chat_message: ChatMessage,
|
||||
brain_id: UUID = Query(..., description="The ID of the brain"),
|
||||
) -> StreamingResponse:
|
||||
commons = common_dependencies()
|
||||
|
||||
qa_chain, callback = create_chain(commons, brain_id)
|
||||
|
||||
return StreamingResponse(
|
||||
send_message(chat_message, qa_chain, callback),
|
||||
media_type="text/event-stream",
|
||||
)
|
@ -1,8 +0,0 @@
|
||||
import os
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from logger import get_logger
|
||||
from models.settings import common_dependencies
|
||||
|
||||
logger = get_logger(__name__)
|
8
backend/utils/constants.py
Normal file
8
backend/utils/constants.py
Normal file
@ -0,0 +1,8 @@
|
||||
openai_function_compatible_models = [
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-4-0613",
|
||||
]
|
||||
|
||||
streaming_compatible_models = ["gpt-3.5-turbo"]
|
||||
|
||||
private_models = ["gpt4all-j-1.3"]
|
@ -8,6 +8,8 @@ type ChatContextProps = {
|
||||
history: ChatHistory[];
|
||||
setHistory: (history: ChatHistory[]) => void;
|
||||
addToHistory: (message: ChatHistory) => void;
|
||||
updateHistory: (chat: ChatHistory) => void;
|
||||
updateStreamingHistory: (streamedChat: ChatHistory) => void;
|
||||
};
|
||||
|
||||
export const ChatContext = createContext<ChatContextProps | undefined>(
|
||||
@ -20,16 +22,54 @@ export const ChatProvider = ({
|
||||
children: JSX.Element | JSX.Element[];
|
||||
}): JSX.Element => {
|
||||
const [history, setHistory] = useState<ChatHistory[]>([]);
|
||||
|
||||
const addToHistory = (message: ChatHistory) => {
|
||||
setHistory((prevHistory) => [...prevHistory, message]);
|
||||
};
|
||||
|
||||
const updateStreamingHistory = (streamedChat: ChatHistory): void => {
|
||||
setHistory((prevHistory: ChatHistory[]) => {
|
||||
console.log("new chat", streamedChat);
|
||||
const updatedHistory = prevHistory.find(
|
||||
(item) => item.message_id === streamedChat.message_id
|
||||
)
|
||||
? prevHistory.map((item: ChatHistory) =>
|
||||
item.message_id === streamedChat.message_id
|
||||
? { ...item, assistant: item.assistant + streamedChat.assistant }
|
||||
: item
|
||||
)
|
||||
: [...prevHistory, streamedChat];
|
||||
|
||||
console.log("updated history", updatedHistory);
|
||||
|
||||
return updatedHistory;
|
||||
});
|
||||
};
|
||||
|
||||
const updateHistory = (chat: ChatHistory): void => {
|
||||
setHistory((prevHistory: ChatHistory[]) => {
|
||||
const updatedHistory = prevHistory.find(
|
||||
(item) => item.message_id === chat.message_id
|
||||
)
|
||||
? prevHistory.map((item: ChatHistory) =>
|
||||
item.message_id === chat.message_id
|
||||
? { ...item, assistant: chat.assistant }
|
||||
: item
|
||||
)
|
||||
: [...prevHistory, chat];
|
||||
|
||||
return updatedHistory;
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<ChatContext.Provider
|
||||
value={{
|
||||
history,
|
||||
setHistory,
|
||||
addToHistory,
|
||||
updateHistory,
|
||||
updateStreamingHistory,
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
|
@ -1,3 +1,4 @@
|
||||
/* eslint-disable max-lines */
|
||||
import { AxiosError } from "axios";
|
||||
import { useParams } from "next/navigation";
|
||||
import { useEffect, useState } from "react";
|
||||
@ -21,30 +22,34 @@ export const useChat = () => {
|
||||
const {
|
||||
config: { maxTokens, model, temperature },
|
||||
} = useBrainConfig();
|
||||
const { history, setHistory, addToHistory } = useChatContext();
|
||||
const { history, setHistory } = useChatContext();
|
||||
const { publish } = useToast();
|
||||
|
||||
const {
|
||||
createChat,
|
||||
getChatHistory,
|
||||
addQuestion: addQuestionToChat,
|
||||
addStreamQuestion,
|
||||
addQuestion: addQuestionToModel,
|
||||
} = useChatService();
|
||||
|
||||
useEffect(() => {
|
||||
const fetchHistory = async () => {
|
||||
const chatHistory = await getChatHistory(chatId);
|
||||
setHistory(chatHistory);
|
||||
const currentChatId = chatId;
|
||||
const chatHistory = await getChatHistory(currentChatId);
|
||||
|
||||
if (chatId === currentChatId && chatHistory.length > 0) {
|
||||
setHistory(chatHistory);
|
||||
}
|
||||
};
|
||||
void fetchHistory();
|
||||
}, [chatId]);
|
||||
}, [chatId, getChatHistory, setHistory]);
|
||||
|
||||
const generateNewChatIdFromName = async (
|
||||
chatName: string
|
||||
): Promise<string> => {
|
||||
const rep = await createChat({ name: chatName });
|
||||
setChatId(rep.data.chat_id);
|
||||
const chat = await createChat({ name: chatName });
|
||||
|
||||
return rep.data.chat_id;
|
||||
return chat.chat_id;
|
||||
};
|
||||
|
||||
const addQuestion = async (question: string, callback?: () => void) => {
|
||||
@ -64,8 +69,15 @@ export const useChat = () => {
|
||||
(await generateNewChatIdFromName(
|
||||
question.split(" ").slice(0, 3).join(" ")
|
||||
));
|
||||
const answer = await addQuestionToChat(currentChatId, chatQuestion);
|
||||
addToHistory(answer);
|
||||
|
||||
setChatId(currentChatId);
|
||||
|
||||
if (chatQuestion.model === "gpt-3.5-turbo") {
|
||||
await addStreamQuestion(currentChatId, chatQuestion);
|
||||
} else {
|
||||
await addQuestionToModel(currentChatId, chatQuestion);
|
||||
}
|
||||
|
||||
callback?.();
|
||||
} catch (error) {
|
||||
console.error({ error });
|
||||
@ -88,5 +100,9 @@ export const useChat = () => {
|
||||
}
|
||||
};
|
||||
|
||||
return { history, addQuestion, generatingAnswer };
|
||||
return {
|
||||
history,
|
||||
addQuestion,
|
||||
generatingAnswer,
|
||||
};
|
||||
};
|
||||
|
@ -1,45 +1,137 @@
|
||||
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
|
||||
import { useAxios } from "@/lib/hooks";
|
||||
/* eslint-disable max-lines */
|
||||
|
||||
import { useCallback } from "react";
|
||||
|
||||
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
|
||||
import { useAxios, useFetch } from "@/lib/hooks";
|
||||
|
||||
import { useChatContext } from "../context/ChatContext";
|
||||
import { ChatEntity, ChatHistory, ChatQuestion } from "../types";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
|
||||
export const useChatService = () => {
|
||||
interface UseChatService {
|
||||
createChat: (name: { name: string }) => Promise<ChatEntity>;
|
||||
getChatHistory: (chatId: string | undefined) => Promise<ChatHistory[]>;
|
||||
addQuestion: (chatId: string, chatQuestion: ChatQuestion) => Promise<void>;
|
||||
addStreamQuestion: (
|
||||
chatId: string,
|
||||
chatQuestion: ChatQuestion
|
||||
) => Promise<void>;
|
||||
}
|
||||
|
||||
export const useChatService = (): UseChatService => {
|
||||
const { axiosInstance } = useAxios();
|
||||
const { fetchInstance } = useFetch();
|
||||
const { updateHistory, updateStreamingHistory } = useChatContext();
|
||||
const { currentBrain } = useBrainContext();
|
||||
const createChat = async ({ name }: { name: string }) => {
|
||||
return axiosInstance.post<ChatEntity>(`/chat`, { name });
|
||||
const createChat = async ({
|
||||
name,
|
||||
}: {
|
||||
name: string;
|
||||
}): Promise<ChatEntity> => {
|
||||
const response = (await axiosInstance.post<ChatEntity>(`/chat`, { name }))
|
||||
.data;
|
||||
|
||||
return response;
|
||||
};
|
||||
|
||||
const getChatHistory = async (chatId: string | undefined) => {
|
||||
if (chatId === undefined) {
|
||||
return [];
|
||||
}
|
||||
const rep = (
|
||||
await axiosInstance.get<ChatHistory[]>(`/chat/${chatId}/history`)
|
||||
).data;
|
||||
const getChatHistory = useCallback(
|
||||
async (chatId: string | undefined): Promise<ChatHistory[]> => {
|
||||
if (chatId === undefined) {
|
||||
return [];
|
||||
}
|
||||
const response = (
|
||||
await axiosInstance.get<ChatHistory[]>(`/chat/${chatId}/history`)
|
||||
).data;
|
||||
|
||||
return response;
|
||||
},
|
||||
[axiosInstance]
|
||||
);
|
||||
|
||||
return rep;
|
||||
};
|
||||
const addQuestion = async (
|
||||
chatId: string,
|
||||
chatQuestion: ChatQuestion
|
||||
): Promise<ChatHistory> => {
|
||||
): Promise<void> => {
|
||||
if (currentBrain?.id === undefined) {
|
||||
throw new Error("No current brain");
|
||||
}
|
||||
|
||||
return (
|
||||
await axiosInstance.post<ChatHistory>(
|
||||
`/chat/${chatId}/question?brain_id=${currentBrain.id}`,
|
||||
chatQuestion
|
||||
)
|
||||
).data;
|
||||
const response = await axiosInstance.post<ChatHistory>(
|
||||
`/chat/${chatId}/question?brain_id=${currentBrain.id}`,
|
||||
chatQuestion
|
||||
);
|
||||
|
||||
updateHistory(response.data);
|
||||
};
|
||||
|
||||
const handleStream = async (
|
||||
reader: ReadableStreamDefaultReader<Uint8Array>
|
||||
): Promise<void> => {
|
||||
const decoder = new TextDecoder("utf-8");
|
||||
|
||||
const handleStreamRecursively = async () => {
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
return;
|
||||
}
|
||||
|
||||
const dataStrings = decoder
|
||||
.decode(value)
|
||||
.trim()
|
||||
.split("data: ")
|
||||
.filter(Boolean);
|
||||
|
||||
dataStrings.forEach((data) => {
|
||||
try {
|
||||
const parsedData = JSON.parse(data) as ChatHistory;
|
||||
updateStreamingHistory(parsedData);
|
||||
} catch (error) {
|
||||
console.error("Error parsing data:", error);
|
||||
}
|
||||
});
|
||||
|
||||
await handleStreamRecursively();
|
||||
};
|
||||
|
||||
await handleStreamRecursively();
|
||||
};
|
||||
|
||||
const addStreamQuestion = async (
|
||||
chatId: string,
|
||||
chatQuestion: ChatQuestion
|
||||
): Promise<void> => {
|
||||
if (currentBrain?.id === undefined) {
|
||||
throw new Error("No current brain");
|
||||
}
|
||||
const headers = {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
};
|
||||
const body = JSON.stringify(chatQuestion);
|
||||
|
||||
try {
|
||||
const response = await fetchInstance.post(
|
||||
`/chat/${chatId}/question/stream?brain_id=${currentBrain.id}`,
|
||||
body,
|
||||
headers
|
||||
);
|
||||
|
||||
if (response.body === null) {
|
||||
throw new Error("Response body is null");
|
||||
}
|
||||
|
||||
console.log("Received response. Starting to handle stream...");
|
||||
await handleStream(response.body.getReader());
|
||||
} catch (error) {
|
||||
console.error("Error calling the API:", error);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
createChat,
|
||||
getChatHistory,
|
||||
addQuestion,
|
||||
addStreamQuestion,
|
||||
};
|
||||
};
|
||||
|
13
frontend/lib/helpers/uuid.ts
Normal file
13
frontend/lib/helpers/uuid.ts
Normal file
@ -0,0 +1,13 @@
|
||||
export const generateUUID = (): string => {
|
||||
const array = new Uint32Array(4);
|
||||
window.crypto.getRandomValues(array);
|
||||
let idx = -1;
|
||||
|
||||
return "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx".replace(/[xy]/g, (c) => {
|
||||
idx++;
|
||||
const r = (array[idx >> 3] >> ((idx % 8) * 4)) & 15;
|
||||
const v = c === "x" ? r : (r & 0x3) | 0x8;
|
||||
|
||||
return v.toString(16);
|
||||
});
|
||||
};
|
@ -1,2 +1,3 @@
|
||||
export * from "./useAxios";
|
||||
export * from "./useFetch";
|
||||
export * from "./useToast";
|
||||
|
72
frontend/lib/hooks/useFetch.ts
Normal file
72
frontend/lib/hooks/useFetch.ts
Normal file
@ -0,0 +1,72 @@
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
import { useBrainConfig } from "../context/BrainConfigProvider/hooks/useBrainConfig";
|
||||
import { useSupabase } from "../context/SupabaseProvider";
|
||||
|
||||
interface FetchInstance {
|
||||
get: (url: string, headers?: HeadersInit) => Promise<Response>;
|
||||
post: (
|
||||
url: string,
|
||||
body: BodyInit | null | undefined,
|
||||
headers?: HeadersInit
|
||||
) => Promise<Response>;
|
||||
put: (
|
||||
url: string,
|
||||
body: BodyInit | null | undefined,
|
||||
headers?: HeadersInit
|
||||
) => Promise<Response>;
|
||||
delete: (url: string, headers?: HeadersInit) => Promise<Response>;
|
||||
}
|
||||
|
||||
const fetchInstance: FetchInstance = {
|
||||
get: async (url, headers) => fetch(url, { method: "GET", headers }),
|
||||
post: async (url, body, headers) =>
|
||||
fetch(url, { method: "POST", body, headers }),
|
||||
put: async (url, body, headers) =>
|
||||
fetch(url, { method: "PUT", body, headers }),
|
||||
delete: async (url, headers) => fetch(url, { method: "DELETE", headers }),
|
||||
};
|
||||
|
||||
export const useFetch = (): { fetchInstance: FetchInstance } => {
|
||||
const { session } = useSupabase();
|
||||
const {
|
||||
config: { backendUrl: configBackendUrl, openAiKey },
|
||||
} = useBrainConfig();
|
||||
|
||||
const [instance, setInstance] = useState(fetchInstance);
|
||||
|
||||
const baseURL = `${process.env.NEXT_PUBLIC_BACKEND_URL ?? ""}`;
|
||||
const backendUrl = configBackendUrl ?? baseURL;
|
||||
|
||||
useEffect(() => {
|
||||
setInstance({
|
||||
...fetchInstance,
|
||||
get: async (url, headers) =>
|
||||
fetchInstance.get(`${backendUrl}${url}`, {
|
||||
Authorization: `Bearer ${session?.access_token ?? ""}`,
|
||||
"Openai-Api-Key": openAiKey ?? "",
|
||||
...headers,
|
||||
}),
|
||||
post: async (url, body, headers) =>
|
||||
fetchInstance.post(`${backendUrl}${url}`, body, {
|
||||
Authorization: `Bearer ${session?.access_token ?? ""}`,
|
||||
"Openai-Api-Key": openAiKey ?? "",
|
||||
...headers,
|
||||
}),
|
||||
put: async (url, body, headers) =>
|
||||
fetchInstance.put(`${backendUrl}${url}`, body, {
|
||||
Authorization: `Bearer ${session?.access_token ?? ""}`,
|
||||
"Openai-Api-Key": openAiKey ?? "",
|
||||
...headers,
|
||||
}),
|
||||
delete: async (url, headers) =>
|
||||
fetchInstance.delete(`${backendUrl}${url}`, {
|
||||
Authorization: `Bearer ${session?.access_token ?? ""}`,
|
||||
"Openai-Api-Key": openAiKey ?? "",
|
||||
...headers,
|
||||
}),
|
||||
});
|
||||
}, [session, backendUrl, openAiKey]);
|
||||
|
||||
return { fetchInstance: instance };
|
||||
};
|
Loading…
Reference in New Issue
Block a user