mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-09-11 14:36:35 +03:00
feat: merge chat history with chat notifications (#1127)
* feat: add chat_id to upload and crawl payload * feat(chat): return chat_history_with_notifications * feat: explicit notification status on create * feat: handle notifications in frontend * feat: delete chat notifications on chat delete request
This commit is contained in:
parent
575d9886c5
commit
9464707d40
@ -232,6 +232,10 @@ class Repository(ABC):
|
||||
def remove_notification_by_id(self, id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_notifications_by_chat_id(self, chat_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_notifications_by_chat_id(self, chat_id: UUID):
|
||||
pass
|
||||
|
@ -93,6 +93,19 @@ class Notifications(Repository):
|
||||
status="deleted", notification_id=notification_id
|
||||
)
|
||||
|
||||
def remove_notifications_by_chat_id(self, chat_id: UUID) -> None:
|
||||
"""
|
||||
Remove all notifications for a chat
|
||||
Args:
|
||||
chat_id (UUID): The id of the chat
|
||||
"""
|
||||
(
|
||||
self.db.from_("notifications")
|
||||
.delete()
|
||||
.filter("chat_id", "eq", chat_id)
|
||||
.execute()
|
||||
).data
|
||||
|
||||
def get_notifications_by_chat_id(self, chat_id: UUID) -> list[Notification]:
|
||||
"""
|
||||
Get all notifications for a chat
|
||||
@ -102,9 +115,11 @@ class Notifications(Repository):
|
||||
Returns:
|
||||
list[Notification]: The notifications
|
||||
"""
|
||||
return (
|
||||
notifications = (
|
||||
self.db.from_("notifications")
|
||||
.select("*")
|
||||
.filter("chat_id", "eq", chat_id)
|
||||
.execute()
|
||||
).data
|
||||
|
||||
return [Notification(**notification) for notification in notifications]
|
||||
|
@ -0,0 +1,57 @@
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
from uuid import UUID
|
||||
|
||||
from models.notifications import Notification
|
||||
from pydantic import BaseModel
|
||||
from utils.parse_message_time import (
|
||||
parse_message_time,
|
||||
)
|
||||
|
||||
from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
|
||||
from repository.notification.get_chat_notifications import (
|
||||
get_chat_notifications,
|
||||
)
|
||||
|
||||
|
||||
class ChatItemType(Enum):
|
||||
MESSAGE = "MESSAGE"
|
||||
NOTIFICATION = "NOTIFICATION"
|
||||
|
||||
|
||||
class ChatItem(BaseModel):
|
||||
item_type: ChatItemType
|
||||
body: Union[GetChatHistoryOutput, Notification]
|
||||
|
||||
|
||||
def merge_chat_history_and_notifications(
|
||||
chat_history: List[GetChatHistoryOutput], notifications: List[Notification]
|
||||
) -> List[ChatItem]:
|
||||
chat_history_and_notifications = chat_history + notifications
|
||||
|
||||
chat_history_and_notifications.sort(
|
||||
key=lambda x: parse_message_time(x.message_time)
|
||||
if isinstance(x, GetChatHistoryOutput)
|
||||
else parse_message_time(x.datetime)
|
||||
)
|
||||
|
||||
transformed_data = []
|
||||
for item in chat_history_and_notifications:
|
||||
if isinstance(item, GetChatHistoryOutput):
|
||||
item_type = ChatItemType.MESSAGE
|
||||
body = item
|
||||
else:
|
||||
item_type = ChatItemType.NOTIFICATION
|
||||
body = item
|
||||
transformed_item = ChatItem(item_type=item_type, body=body)
|
||||
transformed_data.append(transformed_item)
|
||||
|
||||
return transformed_data
|
||||
|
||||
|
||||
def get_chat_history_with_notifications(
|
||||
chat_id: UUID,
|
||||
) -> List[ChatItem]:
|
||||
chat_history = get_chat_history(str(chat_id))
|
||||
chat_notifications = get_chat_notifications(chat_id)
|
||||
return merge_chat_history_and_notifications(chat_history, chat_notifications)
|
14
backend/repository/notification/get_chat_notifications.py
Normal file
14
backend/repository/notification/get_chat_notifications.py
Normal file
@ -0,0 +1,14 @@
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from models.notifications import Notification
|
||||
from models.settings import get_supabase_db
|
||||
|
||||
|
||||
def get_chat_notifications(chat_id: UUID) -> List[Notification]:
|
||||
"""
|
||||
Get notifications by chat_id
|
||||
"""
|
||||
supabase_db = get_supabase_db()
|
||||
|
||||
return supabase_db.get_notifications_by_chat_id(chat_id)
|
11
backend/repository/notification/remove_chat_notifications.py
Normal file
11
backend/repository/notification/remove_chat_notifications.py
Normal file
@ -0,0 +1,11 @@
|
||||
from uuid import UUID
|
||||
from models.settings import get_supabase_db
|
||||
|
||||
|
||||
def remove_chat_notifications(chat_id: UUID) -> None:
|
||||
"""
|
||||
Remove all notifications for a chat
|
||||
"""
|
||||
supabase_db = get_supabase_db()
|
||||
|
||||
supabase_db.remove_notifications_by_chat_id(chat_id)
|
@ -7,6 +7,9 @@ from venv import logger
|
||||
from auth import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from repository.notification.remove_chat_notifications import (
|
||||
remove_chat_notifications,
|
||||
)
|
||||
from llm.openai import OpenAIBrainPicking
|
||||
from llm.qa_headless import HeadlessQA
|
||||
from models import (
|
||||
@ -26,10 +29,13 @@ from repository.chat import (
|
||||
GetChatHistoryOutput,
|
||||
create_chat,
|
||||
get_chat_by_id,
|
||||
get_chat_history,
|
||||
get_user_chats,
|
||||
update_chat,
|
||||
)
|
||||
from repository.chat.get_chat_history_with_notifications import (
|
||||
ChatItem,
|
||||
get_chat_history_with_notifications,
|
||||
)
|
||||
from repository.user_identity import get_user_identity
|
||||
|
||||
chat_router = APIRouter()
|
||||
@ -114,6 +120,8 @@ async def delete_chat(chat_id: UUID):
|
||||
Delete a specific chat by chat ID.
|
||||
"""
|
||||
supabase_db = get_supabase_db()
|
||||
remove_chat_notifications(chat_id)
|
||||
|
||||
delete_chat_from_db(supabase_db=supabase_db, chat_id=chat_id)
|
||||
return {"message": f"{chat_id} has been deleted."}
|
||||
|
||||
@ -333,6 +341,6 @@ async def create_stream_question_handler(
|
||||
)
|
||||
async def get_chat_history_handler(
|
||||
chat_id: UUID,
|
||||
) -> List[GetChatHistoryOutput]:
|
||||
) -> List[ChatItem]:
|
||||
# TODO: RBAC with current_user
|
||||
return get_chat_history(str(chat_id))
|
||||
return get_chat_history_with_notifications(chat_id)
|
||||
|
@ -33,6 +33,7 @@ async def crawl_endpoint(
|
||||
request: Request,
|
||||
crawl_website: CrawlWebsite,
|
||||
brain_id: UUID = Query(..., description="The ID of the brain"),
|
||||
chat_id: UUID = Query(..., description="The ID of the chat"),
|
||||
enable_summarization: bool = False,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
@ -56,11 +57,15 @@ async def crawl_endpoint(
|
||||
"type": "error",
|
||||
}
|
||||
else:
|
||||
crawl_notification = add_notification(
|
||||
CreateNotificationProperties(
|
||||
action="CRAWL",
|
||||
crawl_notification = None
|
||||
if chat_id:
|
||||
crawl_notification = add_notification(
|
||||
CreateNotificationProperties(
|
||||
action="CRAWL",
|
||||
chat_id=chat_id,
|
||||
status=NotificationsStatusEnum.Pending,
|
||||
)
|
||||
)
|
||||
)
|
||||
if not crawl_website.checkGithub():
|
||||
(
|
||||
file_path,
|
||||
@ -92,10 +97,11 @@ async def crawl_endpoint(
|
||||
brain_id=brain_id,
|
||||
user_openai_api_key=request.headers.get("Openai-Api-Key", None),
|
||||
)
|
||||
update_notification_by_id(
|
||||
crawl_notification.id,
|
||||
NotificationUpdatableProperties(
|
||||
status=NotificationsStatusEnum.Done, message=str(message)
|
||||
),
|
||||
)
|
||||
if crawl_notification:
|
||||
update_notification_by_id(
|
||||
crawl_notification.id,
|
||||
NotificationUpdatableProperties(
|
||||
status=NotificationsStatusEnum.Done, message=str(message)
|
||||
),
|
||||
)
|
||||
return message
|
||||
|
@ -36,6 +36,7 @@ async def upload_file(
|
||||
request: Request,
|
||||
uploadFile: UploadFile,
|
||||
brain_id: UUID = Query(..., description="The ID of the brain"),
|
||||
chat_id: UUID = Query(..., description="The ID of the chat"),
|
||||
enable_summarization: bool = False,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
@ -71,11 +72,15 @@ async def upload_file(
|
||||
"type": "error",
|
||||
}
|
||||
else:
|
||||
upload_notification = add_notification(
|
||||
CreateNotificationProperties(
|
||||
action="UPLOAD",
|
||||
upload_notification = None
|
||||
if chat_id:
|
||||
upload_notification = add_notification(
|
||||
CreateNotificationProperties(
|
||||
action="UPLOAD",
|
||||
chat_id=chat_id,
|
||||
status=NotificationsStatusEnum.Pending,
|
||||
)
|
||||
)
|
||||
)
|
||||
openai_api_key = request.headers.get("Openai-Api-Key", None)
|
||||
if openai_api_key is None:
|
||||
brain_details = get_brain_details(brain_id)
|
||||
@ -91,11 +96,13 @@ async def upload_file(
|
||||
brain_id=brain_id,
|
||||
openai_api_key=openai_api_key,
|
||||
)
|
||||
update_notification_by_id(
|
||||
upload_notification.id,
|
||||
NotificationUpdatableProperties(
|
||||
status=NotificationsStatusEnum.Done, message=str(message)
|
||||
),
|
||||
)
|
||||
|
||||
if upload_notification:
|
||||
update_notification_by_id(
|
||||
upload_notification.id,
|
||||
NotificationUpdatableProperties(
|
||||
status=NotificationsStatusEnum.Done, message=str(message)
|
||||
),
|
||||
)
|
||||
|
||||
return message
|
||||
|
5
backend/utils/parse_message_time.py
Normal file
5
backend/utils/parse_message_time.py
Normal file
@ -0,0 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def parse_message_time(message_time_str):
|
||||
return datetime.strptime(message_time_str, "%Y-%m-%dT%H:%M:%S.%f")
|
@ -1,9 +1,11 @@
|
||||
/* eslint-disable max-lines */
|
||||
import axios from "axios";
|
||||
import { UUID } from "crypto";
|
||||
import { useParams } from "next/navigation";
|
||||
import { useCallback, useState } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
|
||||
import { useChatApi } from "@/lib/api/chat/useChatApi";
|
||||
import { useCrawlApi } from "@/lib/api/crawl/useCrawlApi";
|
||||
import { useUploadApi } from "@/lib/api/upload/useUploadApi";
|
||||
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
|
||||
@ -18,6 +20,10 @@ export const useKnowledgeUploader = () => {
|
||||
const { uploadFile } = useUploadApi();
|
||||
const { t } = useTranslation(["upload"]);
|
||||
const { crawlWebsiteUrl } = useCrawlApi();
|
||||
const { createChat } = useChatApi();
|
||||
|
||||
const params = useParams();
|
||||
const chatId = params?.chatId as UUID | undefined;
|
||||
|
||||
const { currentBrainId } = useBrainContext();
|
||||
const addContent = (content: FeedItemType) => {
|
||||
@ -28,7 +34,7 @@ export const useKnowledgeUploader = () => {
|
||||
};
|
||||
|
||||
const crawlWebsiteHandler = useCallback(
|
||||
async (url: string, brainId: UUID) => {
|
||||
async (url: string, brainId: UUID, chat_id: UUID) => {
|
||||
// Configure parameters
|
||||
const config = {
|
||||
url: url,
|
||||
@ -42,6 +48,7 @@ export const useKnowledgeUploader = () => {
|
||||
await crawlWebsiteUrl({
|
||||
brainId,
|
||||
config,
|
||||
chat_id,
|
||||
});
|
||||
} catch (error: unknown) {
|
||||
publish({
|
||||
@ -56,13 +63,14 @@ export const useKnowledgeUploader = () => {
|
||||
);
|
||||
|
||||
const uploadFileHandler = useCallback(
|
||||
async (file: File, brainId: UUID) => {
|
||||
async (file: File, brainId: UUID, chat_id: UUID) => {
|
||||
const formData = new FormData();
|
||||
formData.append("uploadFile", file);
|
||||
try {
|
||||
await uploadFile({
|
||||
brainId: brainId,
|
||||
brainId,
|
||||
formData,
|
||||
chat_id,
|
||||
});
|
||||
} catch (e: unknown) {
|
||||
if (axios.isAxiosError(e) && e.response?.status === 403) {
|
||||
@ -104,12 +112,14 @@ export const useKnowledgeUploader = () => {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const currentChatId = chatId ?? (await createChat("New Chat")).chat_id;
|
||||
const uploadPromises = files.map((file) =>
|
||||
uploadFileHandler(file, currentBrainId)
|
||||
uploadFileHandler(file, currentBrainId, currentChatId)
|
||||
);
|
||||
const crawlPromises = urls.map((url) =>
|
||||
crawlWebsiteHandler(url, currentBrainId)
|
||||
crawlWebsiteHandler(url, currentBrainId, currentChatId)
|
||||
);
|
||||
|
||||
await Promise.all([...uploadPromises, ...crawlPromises]);
|
||||
|
@ -4,6 +4,8 @@ import { useEffect } from "react";
|
||||
import { useChatApi } from "@/lib/api/chat/useChatApi";
|
||||
import { useChatContext } from "@/lib/context";
|
||||
|
||||
import { getMessagesFromChatHistory } from "../utils/getMessagesFromChatHistory";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
|
||||
export const useSelectedChatPage = () => {
|
||||
const { setHistory } = useChatContext();
|
||||
@ -23,7 +25,7 @@ export const useSelectedChatPage = () => {
|
||||
const chatHistory = await getHistory(chatId);
|
||||
|
||||
if (chatHistory.length > 0) {
|
||||
setHistory(chatHistory);
|
||||
setHistory(getMessagesFromChatHistory(chatHistory));
|
||||
}
|
||||
};
|
||||
void fetchHistory();
|
||||
|
@ -18,6 +18,22 @@ export type ChatHistory = {
|
||||
brain_name?: string;
|
||||
};
|
||||
|
||||
type HistoryItemType = "MESSAGE" | "NOTIFICATION";
|
||||
|
||||
type Notification = {
|
||||
id: string;
|
||||
datetime: string;
|
||||
chat_id?: string | null;
|
||||
message?: string | null;
|
||||
action: string;
|
||||
status: string;
|
||||
};
|
||||
|
||||
export type ChatItem = {
|
||||
item_type: HistoryItemType;
|
||||
body: ChatHistory | Notification;
|
||||
};
|
||||
|
||||
export type ChatEntity = {
|
||||
chat_id: UUID;
|
||||
user_id: string;
|
||||
|
@ -0,0 +1,11 @@
|
||||
import { ChatHistory, ChatItem } from "../types";
|
||||
|
||||
export const getMessagesFromChatHistory = (
|
||||
chatHistory: ChatItem[]
|
||||
): ChatHistory[] => {
|
||||
const messages = chatHistory
|
||||
.filter((item) => item.item_type === "MESSAGE")
|
||||
.map((item) => item.body as ChatHistory);
|
||||
|
||||
return messages;
|
||||
};
|
@ -3,6 +3,7 @@ import { AxiosInstance } from "axios";
|
||||
import {
|
||||
ChatEntity,
|
||||
ChatHistory,
|
||||
ChatItem,
|
||||
ChatQuestion,
|
||||
} from "@/app/chat/[chatId]/types";
|
||||
|
||||
@ -55,8 +56,8 @@ export const addQuestion = async (
|
||||
export const getHistory = async (
|
||||
chatId: string,
|
||||
axiosInstance: AxiosInstance
|
||||
): Promise<ChatHistory[]> =>
|
||||
(await axiosInstance.get<ChatHistory[]>(`/chat/${chatId}/history`)).data;
|
||||
): Promise<ChatItem[]> =>
|
||||
(await axiosInstance.get<ChatItem[]>(`/chat/${chatId}/history`)).data;
|
||||
|
||||
export type ChatUpdatableProperties = {
|
||||
chat_name?: string;
|
||||
|
@ -24,6 +24,7 @@ describe("useCrawlApi", () => {
|
||||
} = renderHook(() => useCrawlApi());
|
||||
const crawlInputProps: CrawlInputProps = {
|
||||
brainId: "e7001ccd-6d90-4eab-8c50-2f23d39441e4",
|
||||
chat_id: "e7001ccd-6d90-4eab-8c50-2f23d39441es",
|
||||
config: {
|
||||
url: "https://en.wikipedia.org/wiki/Mali",
|
||||
js: false,
|
||||
@ -36,7 +37,9 @@ describe("useCrawlApi", () => {
|
||||
|
||||
expect(axiosPostMock).toHaveBeenCalledTimes(1);
|
||||
expect(axiosPostMock).toHaveBeenCalledWith(
|
||||
`/crawl?brain_id=${crawlInputProps.brainId}`,
|
||||
`/crawl?brain_id=${crawlInputProps.brainId}&chat_id=${
|
||||
crawlInputProps.chat_id ?? ""
|
||||
}`,
|
||||
crawlInputProps.config
|
||||
);
|
||||
});
|
||||
|
@ -5,6 +5,7 @@ import { ToastData } from "@/lib/components/ui/Toast/domain/types";
|
||||
|
||||
export type CrawlInputProps = {
|
||||
brainId: UUID;
|
||||
chat_id?: UUID;
|
||||
config: {
|
||||
url: string;
|
||||
js: boolean;
|
||||
@ -22,4 +23,7 @@ export const crawlWebsiteUrl = async (
|
||||
props: CrawlInputProps,
|
||||
axiosInstance: AxiosInstance
|
||||
): Promise<CrawlResponse> =>
|
||||
axiosInstance.post(`/crawl?brain_id=${props.brainId}`, props.config);
|
||||
axiosInstance.post(
|
||||
`/crawl?brain_id=${props.brainId}&chat_id=${props.chat_id ?? ""}`,
|
||||
props.config
|
||||
);
|
||||
|
@ -10,10 +10,14 @@ export type UploadResponse = {
|
||||
export type UploadInputProps = {
|
||||
brainId: UUID;
|
||||
formData: FormData;
|
||||
chat_id?: UUID;
|
||||
};
|
||||
|
||||
export const uploadFile = async (
|
||||
props: UploadInputProps,
|
||||
axiosInstance: AxiosInstance
|
||||
): Promise<UploadResponse> =>
|
||||
axiosInstance.post(`/upload?brain_id=${props.brainId}`, props.formData);
|
||||
axiosInstance.post(
|
||||
`/upload?brain_id=${props.brainId}&chat_id=${props.chat_id ?? ""}`,
|
||||
props.formData
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user