Mirror of https://github.com/StanGirard/quivr.git (synced 2024-12-24 11:52:45 +03:00)
feat(Unplug): chatting without brain streaming (#970)

* feat(Unplug): adds a new basic headless LLM
* feat(Unplug): adds a chatting-without-brain option when not streaming
* feat(Unplug): adds a chatting-without-brain option when streaming
parent 7281fd905a
commit 600ff1ede0
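For orientation, the sketch below shows roughly how the new HeadlessQA class is meant to be used; it mirrors the no-brain branch added to the question route further down. The names chat_id, chat_question and current_user are assumed to come from the FastAPI handler, and the Quivr backend package plus an OpenAI key are assumed to be available, so treat this as an illustration rather than part of the commit.

# Illustrative sketch only: mirrors the else-branch added below for requests without a brain_id.
from llm.qa_headless import HeadlessQA

gpt_answer_generator = HeadlessQA(
    model=chat_question.model,
    temperature=chat_question.temperature,
    max_tokens=chat_question.max_tokens,
    user_openai_api_key=current_user.user_openai_api_key,
    chat_id=str(chat_id),
)
chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)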
@@ -1,9 +1,11 @@
 from .base import BaseBrainPicking
 from .qa_base import QABaseBrainPicking
 from .openai import OpenAIBrainPicking
+from .qa_headless import HeadlessQA
 
 __all__ = [
     "BaseBrainPicking",
     "QABaseBrainPicking",
     "OpenAIBrainPicking",
+    "HeadlessQA"
 ]
backend/core/llm/qa_headless.py (new file, 207 lines)
@@ -0,0 +1,207 @@
import asyncio
import json
from uuid import UUID

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from repository.chat.update_message_by_id import update_message_by_id
from models.databases.supabase.chats import CreateChatHistory
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.format_chat_history import format_history_to_openai_mesages
from logger import get_logger
from models.chats import ChatQuestion
from repository.chat.get_chat_history import GetChatHistoryOutput


from pydantic import BaseModel

from typing import AsyncIterable, Awaitable, List

logger = get_logger(__name__)
SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."


class HeadlessQA(BaseModel):
    model: str = None  # type: ignore
    temperature: float = 0.0
    max_tokens: int = 256
    user_openai_api_key: str = None  # type: ignore
    openai_api_key: str = None  # type: ignore
    streaming: bool = False
    chat_id: str = None  # type: ignore
    callbacks: List[AsyncIteratorCallbackHandler] = None  # type: ignore

    def _determine_api_key(self, openai_api_key, user_openai_api_key):
        """If user provided an API key, use it."""
        if user_openai_api_key is not None:
            return user_openai_api_key
        else:
            return openai_api_key

    def _determine_streaming(self, model: str, streaming: bool) -> bool:
        """If the model name allows for streaming and streaming is declared, set streaming to True."""
        return streaming

    def _determine_callback_array(
        self, streaming
    ) -> List[AsyncIteratorCallbackHandler]:  # pyright: ignore reportPrivateUsage=none
        """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
        if streaming:
            return [
                AsyncIteratorCallbackHandler()  # pyright: ignore reportPrivateUsage=none
            ]

    def __init__(self, **data):
        super().__init__(**data)

        self.openai_api_key = self._determine_api_key(
            self.openai_api_key, self.user_openai_api_key
        )
        self.streaming = self._determine_streaming(
            self.model, self.streaming
        )  # pyright: ignore reportPrivateUsage=none
        self.callbacks = self._determine_callback_array(
            self.streaming
        )  # pyright: ignore reportPrivateUsage=none

    def _create_llm(
        self, model, temperature=0, streaming=False, callbacks=None
    ) -> BaseLLM:
        """
        Determine the language model to be used.
        :param model: Language model name to be used.
        :param streaming: Whether to enable streaming of the model
        :param callbacks: Callbacks to be used for streaming
        :return: Language model instance
        """
        return ChatOpenAI(
            temperature=temperature,
            model=model,
            streaming=streaming,
            verbose=True,
            callbacks=callbacks,
            openai_api_key=self.openai_api_key,
        )  # pyright: ignore reportPrivateUsage=none

    def _create_prompt_template(self):
        messages = [
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
        return CHAT_PROMPT

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion
    ) -> GetChatHistoryOutput:
        transformed_history = format_chat_history(get_chat_history(self.chat_id))
        messages = format_history_to_openai_mesages(transformed_history, SYSTEM_MESSAGE, question.question)
        answering_llm = self._create_llm(
            model=self.model, streaming=False, callbacks=self.callbacks
        )
        model_prediction = answering_llm.predict_messages(messages)  # pyright: ignore reportPrivateUsage=none
        answer = model_prediction.content

        new_chat = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "brain_id": None,
                    "prompt_id": None,
                }
            )
        )

        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                "message_time": new_chat.message_time,
                "prompt_title": None,
                "brain_name": None,
                "message_id": new_chat.message_id,
            }
        )

    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion
    ) -> AsyncIterable:
        callback = AsyncIteratorCallbackHandler()
        self.callbacks = [callback]

        transformed_history = format_chat_history(get_chat_history(self.chat_id))
        messages = format_history_to_openai_mesages(transformed_history, SYSTEM_MESSAGE, question.question)
        answering_llm = self._create_llm(
            model=self.model, streaming=True, callbacks=self.callbacks
        )

        CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
        headlessChain = LLMChain(llm=answering_llm, prompt=CHAT_PROMPT)

        response_tokens = []

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
            except Exception as e:
                logger.error(f"Caught exception: {e}")
            finally:
                event.set()

        run = asyncio.create_task(
            wrap_done(
                headlessChain.acall({}),
                callback.done,
            ),
        )

        streamed_chat_history = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": "",
                    "brain_id": None,
                    "prompt_id": None,
                }
            )
        )

        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": streamed_chat_history.message_id,
                "message_time": streamed_chat_history.message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": None,
                "brain_name": None,
            }
        )

        async for token in callback.aiter():
            logger.info("Token: %s", token)  # type: ignore
            response_tokens.append(token)  # type: ignore
            streamed_chat_history.assistant = token  # type: ignore
            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        await run
        assistant = "".join(response_tokens)

        update_message_by_id(
            message_id=str(streamed_chat_history.message_id),
            user_message=question.question,
            assistant=assistant,
        )

    class Config:
        arbitrary_types_allowed = True
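A note on the streaming contract: each chunk yielded by generate_stream above is the JSON dump of a GetChatHistoryOutput prefixed with "data: ", which the frontend splits apart again. Below is a minimal, self-contained sketch of that parsing (field values are invented for the example and are not from the commit):

# Not part of the commit: parse one streamed chunk the way the frontend does.
import json

chunk = 'data: {"chat_id": "123", "user_message": "Hi", "assistant": "Hel"}'
payload = json.loads(chunk.split("data: ", 1)[1])
print(payload["assistant"])  # the partial answer received so far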
@@ -1,4 +1,19 @@
-def format_chat_history(history) -> list[tuple[str, str]]:
+from typing import List, Tuple
+from langchain.schema import AIMessage, HumanMessage, SystemMessage
+
+
+def format_chat_history(history) -> List[Tuple[str, str]]:
     """Format the chat history into a list of tuples (human, ai)"""
 
     return [(chat.user_message, chat.assistant) for chat in history]
+
+
+def format_history_to_openai_mesages(tuple_history: List[Tuple[str, str]], system_message: str, question: str) -> List[SystemMessage | HumanMessage | AIMessage]:
+    """Format the chat history into a list of Base Messages"""
+    messages = []
+    messages.append(SystemMessage(content=system_message))
+    for human, ai in tuple_history:
+        messages.append(HumanMessage(content=human))
+        messages.append(AIMessage(content=ai))
+    messages.append(HumanMessage(content=question))
+    return messages
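To make the new helper concrete, here is a small, self-contained illustration (not part of the commit; the sample strings are invented) of the message list format_history_to_openai_mesages builds. It only assumes langchain is installed:

# Inline expansion of the helper for a one-turn history.
from langchain.schema import AIMessage, HumanMessage, SystemMessage

tuple_history = [("Hi", "Hello! How can I help?")]
system_message = "Your name is Quivr. You're a helpful assistant."
question = "What can you do?"

messages = [SystemMessage(content=system_message)]
for human, ai in tuple_history:
    messages.append(HumanMessage(content=human))
    messages.append(AIMessage(content=ai))
messages.append(HumanMessage(content=question))
# -> [SystemMessage, HumanMessage("Hi"), AIMessage("Hello! ..."), HumanMessage("What can you do?")]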
@@ -7,6 +7,7 @@ from venv import logger
 from auth import AuthBearer, get_current_user
 from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from fastapi.responses import StreamingResponse
+from llm.qa_headless import HeadlessQA
 from llm.openai import OpenAIBrainPicking
 from models.brains import Brain
 from models.brain_entity import BrainEntity
@@ -16,9 +17,6 @@ from models.databases.supabase.supabase import SupabaseDB
 from models.settings import LLMSettings, get_supabase_db
 from models.users import User
 from repository.brain.get_brain_details import get_brain_details
-from repository.brain.get_default_user_brain_or_create_new import (
-    get_default_user_brain_or_create_new,
-)
 from repository.chat.create_chat import CreateChatProperties, create_chat
 from repository.chat.get_chat_by_id import get_chat_by_id
 from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
@@ -190,17 +188,24 @@
     check_user_limit(current_user)
     LLMSettings()
 
-    if not brain_id:
-        brain_id = get_default_user_brain_or_create_new(current_user).brain_id
-    gpt_answer_generator = OpenAIBrainPicking(
-        chat_id=str(chat_id),
-        model=chat_question.model,
-        max_tokens=chat_question.max_tokens,
-        temperature=chat_question.temperature,
-        brain_id=str(brain_id),
-        user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
-    )
+    gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
+    if brain_id:
+        gpt_answer_generator = OpenAIBrainPicking(
+            chat_id=str(chat_id),
+            model=chat_question.model,
+            max_tokens=chat_question.max_tokens,
+            temperature=chat_question.temperature,
+            brain_id=str(brain_id),
+            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+        )
+    else:
+        gpt_answer_generator = HeadlessQA(
+            model=chat_question.model,
+            temperature=chat_question.temperature,
+            max_tokens=chat_question.max_tokens,
+            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+            chat_id=str(chat_id),
+        )
 
     chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)
 
@@ -259,18 +264,26 @@
     try:
         logger.info(f"Streaming request for {chat_question.model}")
         check_user_limit(current_user)
-        if not brain_id:
-            brain_id = get_default_user_brain_or_create_new(current_user).brain_id
-        gpt_answer_generator = OpenAIBrainPicking(
-            chat_id=str(chat_id),
-            model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo",
-            max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0,
-            temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256,
-            brain_id=str(brain_id),
-            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
-            streaming=True,
-        )
+        gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
+        if brain_id:
+            gpt_answer_generator = OpenAIBrainPicking(
+                chat_id=str(chat_id),
+                model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo",
+                max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0,
+                temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256,
+                brain_id=str(brain_id),
+                user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+                streaming=True,
+            )
+        else:
+            gpt_answer_generator = HeadlessQA(
+                model=chat_question.model if current_user.user_openai_api_key else "gpt-3.5-turbo",
+                temperature=chat_question.temperature if current_user.user_openai_api_key else 256,
+                max_tokens=chat_question.max_tokens if current_user.user_openai_api_key else 0,
+                user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+                chat_id=str(chat_id),
+                streaming=True,
+            )
 
         print("streaming")
         return StreamingResponse(
@@ -1,18 +1,18 @@
 /* eslint-disable max-lines */
-import { AxiosError } from "axios";
-import { useParams } from "next/navigation";
-import { useState } from "react";
-import { useTranslation } from "react-i18next";
+import { AxiosError } from 'axios';
+import { useParams } from 'next/navigation';
+import { useState } from 'react';
+import { useTranslation } from 'react-i18next';
 
-import { getChatConfigFromLocalStorage } from "@/lib/api/chat/chat.local";
-import { useChatApi } from "@/lib/api/chat/useChatApi";
-import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
-import { useChatContext } from "@/lib/context/ChatProvider/hooks/useChatContext";
-import { useToast } from "@/lib/hooks";
-import { useEventTracking } from "@/services/analytics/useEventTracking";
+import { getChatConfigFromLocalStorage } from '@/lib/api/chat/chat.local';
+import { useChatApi } from '@/lib/api/chat/useChatApi';
+import { useBrainContext } from '@/lib/context/BrainProvider/hooks/useBrainContext';
+import { useChatContext } from '@/lib/context/ChatProvider/hooks/useChatContext';
+import { useToast } from '@/lib/hooks';
+import { useEventTracking } from '@/services/analytics/useEventTracking';
 
-import { useQuestion } from "./useQuestion";
-import { ChatQuestion } from "../types";
+import { useQuestion } from './useQuestion';
+import { ChatQuestion } from '../types';
 
 // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
 export const useChat = () => {
@@ -29,7 +29,7 @@ export const useChat = () => {
   const { createChat } = useChatApi();
 
   const { addStreamQuestion } = useQuestion();
-  const { t } = useTranslation(["chat"]);
+  const { t } = useTranslation(['chat']);
 
   const addQuestion = async (question: string, callback?: () => void) => {
     try {
@@ -39,22 +39,13 @@ export const useChat = () => {
 
       //if chatId is not set, create a new chat. Chat name is from the first question
       if (currentChatId === undefined) {
-        const chatName = question.split(" ").slice(0, 3).join(" ");
+        const chatName = question.split(' ').slice(0, 3).join(' ');
         const chat = await createChat(chatName);
         currentChatId = chat.chat_id;
         setChatId(currentChatId);
       }
 
-      if (currentBrain?.id === undefined) {
-        publish({
-          variant: "danger",
-          text: t("missing_brain"),
-        });
-
-        return;
-      }
-
-      void track("QUESTION_ASKED");
+      void track('QUESTION_ASKED');
       const chatConfig = getChatConfigFromLocalStorage(currentChatId);
 
       const chatQuestion: ChatQuestion = {
@@ -62,7 +53,7 @@ export const useChat = () => {
         question,
         temperature: chatConfig?.temperature,
         max_tokens: chatConfig?.maxTokens,
-        brain_id: currentBrain.id,
+        brain_id: currentBrain?.id,
       };
 
       await addStreamQuestion(currentChatId, chatQuestion);
@@ -73,16 +64,16 @@ export const useChat = () => {
 
       if ((error as AxiosError).response?.status === 429) {
         publish({
-          variant: "danger",
-          text: t("limit_reached", { ns: "chat" }),
+          variant: 'danger',
+          text: t('limit_reached', { ns: 'chat' }),
         });
 
         return;
       }
 
       publish({
-        variant: "danger",
-        text: t("error_occurred", { ns: "chat" }),
+        variant: 'danger',
+        text: t('error_occurred', { ns: 'chat' }),
       });
     } finally {
       setGeneratingAnswer(false);
@@ -1,12 +1,12 @@
 /* eslint-disable max-lines */
 
-import { useTranslation } from "react-i18next";
+import { useTranslation } from 'react-i18next';
 
-import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
-import { useChatContext } from "@/lib/context/ChatProvider/hooks/useChatContext";
-import { useFetch } from "@/lib/hooks";
+import { useBrainContext } from '@/lib/context/BrainProvider/hooks/useBrainContext';
+import { useChatContext } from '@/lib/context/ChatProvider/hooks/useChatContext';
+import { useFetch } from '@/lib/hooks';
 
-import { ChatHistory, ChatQuestion } from "../types";
+import { ChatHistory, ChatQuestion } from '../types';
 
 interface UseChatService {
   addStreamQuestion: (
@@ -20,12 +20,12 @@ export const useQuestion = (): UseChatService => {
   const { updateStreamingHistory } = useChatContext();
   const { currentBrain } = useBrainContext();
 
-  const { t } = useTranslation(["chat"]);
+  const { t } = useTranslation(['chat']);
 
   const handleStream = async (
     reader: ReadableStreamDefaultReader<Uint8Array>
   ): Promise<void> => {
-    const decoder = new TextDecoder("utf-8");
+    const decoder = new TextDecoder('utf-8');
 
     const handleStreamRecursively = async () => {
       const { done, value } = await reader.read();
@@ -37,7 +37,7 @@ export const useQuestion = (): UseChatService => {
       const dataStrings = decoder
         .decode(value)
         .trim()
-        .split("data: ")
+        .split('data: ')
         .filter(Boolean);
 
       dataStrings.forEach((data) => {
@@ -45,7 +45,7 @@ export const useQuestion = (): UseChatService => {
          const parsedData = JSON.parse(data) as ChatHistory;
          updateStreamingHistory(parsedData);
        } catch (error) {
-          console.error(t("errorParsingData", { ns: "chat" }), error);
+          console.error(t('errorParsingData', { ns: 'chat' }), error);
        }
      });
 
@@ -59,30 +59,27 @@ export const useQuestion = (): UseChatService => {
     chatId: string,
     chatQuestion: ChatQuestion
   ): Promise<void> => {
-    if (currentBrain?.id === undefined) {
-      throw new Error(t("noCurrentBrain", { ns: "chat" }));
-    }
     const headers = {
-      "Content-Type": "application/json",
-      Accept: "text/event-stream",
+      'Content-Type': 'application/json',
+      Accept: 'text/event-stream',
     };
     const body = JSON.stringify(chatQuestion);
-    console.log("Calling API...");
+    console.log('Calling API...');
     try {
       const response = await fetchInstance.post(
-        `/chat/${chatId}/question/stream?brain_id=${currentBrain.id}`,
+        `/chat/${chatId}/question/stream?brain_id=${currentBrain?.id ?? ''}`,
         body,
         headers
       );
 
       if (response.body === null) {
-        throw new Error(t("resposeBodyNull", { ns: "chat" }));
+        throw new Error(t('resposeBodyNull', { ns: 'chat' }));
       }
 
-      console.log(t("receivedResponse"), response);
+      console.log(t('receivedResponse'), response);
       await handleStream(response.body.getReader());
     } catch (error) {
-      console.error(t("errorCallingAPI", { ns: "chat" }), error);
+      console.error(t('errorCallingAPI', { ns: 'chat' }), error);
     }
   };
 