feat: merge chat history with chat notifications (#1127)

* feat: add chat_id to upload and crawl payload

* feat(chat): return chat_history_with_notifications

* feat: explicit notification status on create

* feat: handle notifications in frontend

* feat: delete chat notifications on chat delete request
This commit is contained in:
Mamadou DICKO 2023-09-07 17:23:31 +02:00 committed by GitHub
parent 575d9886c5
commit 9464707d40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 213 additions and 35 deletions

View File

@ -232,6 +232,10 @@ class Repository(ABC):
def remove_notification_by_id(self, id: UUID):
pass
@abstractmethod
def remove_notifications_by_chat_id(self, chat_id: UUID):
pass
@abstractmethod
def get_notifications_by_chat_id(self, chat_id: UUID):
pass

View File

@ -93,6 +93,19 @@ class Notifications(Repository):
status="deleted", notification_id=notification_id
)
def remove_notifications_by_chat_id(self, chat_id: UUID) -> None:
"""
Remove all notifications for a chat
Args:
chat_id (UUID): The id of the chat
"""
(
self.db.from_("notifications")
.delete()
.filter("chat_id", "eq", chat_id)
.execute()
).data
def get_notifications_by_chat_id(self, chat_id: UUID) -> list[Notification]:
"""
Get all notifications for a chat
@ -102,9 +115,11 @@ class Notifications(Repository):
Returns:
list[Notification]: The notifications
"""
return (
notifications = (
self.db.from_("notifications")
.select("*")
.filter("chat_id", "eq", chat_id)
.execute()
).data
return [Notification(**notification) for notification in notifications]

View File

@ -0,0 +1,57 @@
from enum import Enum
from typing import List, Union
from uuid import UUID
from models.notifications import Notification
from pydantic import BaseModel
from utils.parse_message_time import (
parse_message_time,
)
from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
from repository.notification.get_chat_notifications import (
get_chat_notifications,
)
class ChatItemType(Enum):
MESSAGE = "MESSAGE"
NOTIFICATION = "NOTIFICATION"
class ChatItem(BaseModel):
item_type: ChatItemType
body: Union[GetChatHistoryOutput, Notification]
def merge_chat_history_and_notifications(
chat_history: List[GetChatHistoryOutput], notifications: List[Notification]
) -> List[ChatItem]:
chat_history_and_notifications = chat_history + notifications
chat_history_and_notifications.sort(
key=lambda x: parse_message_time(x.message_time)
if isinstance(x, GetChatHistoryOutput)
else parse_message_time(x.datetime)
)
transformed_data = []
for item in chat_history_and_notifications:
if isinstance(item, GetChatHistoryOutput):
item_type = ChatItemType.MESSAGE
body = item
else:
item_type = ChatItemType.NOTIFICATION
body = item
transformed_item = ChatItem(item_type=item_type, body=body)
transformed_data.append(transformed_item)
return transformed_data
def get_chat_history_with_notifications(
chat_id: UUID,
) -> List[ChatItem]:
chat_history = get_chat_history(str(chat_id))
chat_notifications = get_chat_notifications(chat_id)
return merge_chat_history_and_notifications(chat_history, chat_notifications)

View File

@ -0,0 +1,14 @@
from typing import List
from uuid import UUID
from models.notifications import Notification
from models.settings import get_supabase_db
def get_chat_notifications(chat_id: UUID) -> List[Notification]:
"""
Get notifications by chat_id
"""
supabase_db = get_supabase_db()
return supabase_db.get_notifications_by_chat_id(chat_id)

View File

@ -0,0 +1,11 @@
from uuid import UUID
from models.settings import get_supabase_db
def remove_chat_notifications(chat_id: UUID) -> None:
"""
Remove all notifications for a chat
"""
supabase_db = get_supabase_db()
supabase_db.remove_notifications_by_chat_id(chat_id)

View File

@ -7,6 +7,9 @@ from venv import logger
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from repository.notification.remove_chat_notifications import (
remove_chat_notifications,
)
from llm.openai import OpenAIBrainPicking
from llm.qa_headless import HeadlessQA
from models import (
@ -26,10 +29,13 @@ from repository.chat import (
GetChatHistoryOutput,
create_chat,
get_chat_by_id,
get_chat_history,
get_user_chats,
update_chat,
)
from repository.chat.get_chat_history_with_notifications import (
ChatItem,
get_chat_history_with_notifications,
)
from repository.user_identity import get_user_identity
chat_router = APIRouter()
@ -114,6 +120,8 @@ async def delete_chat(chat_id: UUID):
Delete a specific chat by chat ID.
"""
supabase_db = get_supabase_db()
remove_chat_notifications(chat_id)
delete_chat_from_db(supabase_db=supabase_db, chat_id=chat_id)
return {"message": f"{chat_id} has been deleted."}
@ -333,6 +341,6 @@ async def create_stream_question_handler(
)
async def get_chat_history_handler(
chat_id: UUID,
) -> List[GetChatHistoryOutput]:
) -> List[ChatItem]:
# TODO: RBAC with current_user
return get_chat_history(str(chat_id))
return get_chat_history_with_notifications(chat_id)

View File

@ -33,6 +33,7 @@ async def crawl_endpoint(
request: Request,
crawl_website: CrawlWebsite,
brain_id: UUID = Query(..., description="The ID of the brain"),
chat_id: UUID = Query(..., description="The ID of the chat"),
enable_summarization: bool = False,
current_user: UserIdentity = Depends(get_current_user),
):
@ -56,11 +57,15 @@ async def crawl_endpoint(
"type": "error",
}
else:
crawl_notification = add_notification(
CreateNotificationProperties(
action="CRAWL",
crawl_notification = None
if chat_id:
crawl_notification = add_notification(
CreateNotificationProperties(
action="CRAWL",
chat_id=chat_id,
status=NotificationsStatusEnum.Pending,
)
)
)
if not crawl_website.checkGithub():
(
file_path,
@ -92,10 +97,11 @@ async def crawl_endpoint(
brain_id=brain_id,
user_openai_api_key=request.headers.get("Openai-Api-Key", None),
)
update_notification_by_id(
crawl_notification.id,
NotificationUpdatableProperties(
status=NotificationsStatusEnum.Done, message=str(message)
),
)
if crawl_notification:
update_notification_by_id(
crawl_notification.id,
NotificationUpdatableProperties(
status=NotificationsStatusEnum.Done, message=str(message)
),
)
return message

View File

@ -36,6 +36,7 @@ async def upload_file(
request: Request,
uploadFile: UploadFile,
brain_id: UUID = Query(..., description="The ID of the brain"),
chat_id: UUID = Query(..., description="The ID of the chat"),
enable_summarization: bool = False,
current_user: UserIdentity = Depends(get_current_user),
):
@ -71,11 +72,15 @@ async def upload_file(
"type": "error",
}
else:
upload_notification = add_notification(
CreateNotificationProperties(
action="UPLOAD",
upload_notification = None
if chat_id:
upload_notification = add_notification(
CreateNotificationProperties(
action="UPLOAD",
chat_id=chat_id,
status=NotificationsStatusEnum.Pending,
)
)
)
openai_api_key = request.headers.get("Openai-Api-Key", None)
if openai_api_key is None:
brain_details = get_brain_details(brain_id)
@ -91,11 +96,13 @@ async def upload_file(
brain_id=brain_id,
openai_api_key=openai_api_key,
)
update_notification_by_id(
upload_notification.id,
NotificationUpdatableProperties(
status=NotificationsStatusEnum.Done, message=str(message)
),
)
if upload_notification:
update_notification_by_id(
upload_notification.id,
NotificationUpdatableProperties(
status=NotificationsStatusEnum.Done, message=str(message)
),
)
return message

View File

@ -0,0 +1,5 @@
from datetime import datetime
def parse_message_time(message_time_str):
return datetime.strptime(message_time_str, "%Y-%m-%dT%H:%M:%S.%f")

View File

@ -1,9 +1,11 @@
/* eslint-disable max-lines */
import axios from "axios";
import { UUID } from "crypto";
import { useParams } from "next/navigation";
import { useCallback, useState } from "react";
import { useTranslation } from "react-i18next";
import { useChatApi } from "@/lib/api/chat/useChatApi";
import { useCrawlApi } from "@/lib/api/crawl/useCrawlApi";
import { useUploadApi } from "@/lib/api/upload/useUploadApi";
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
@ -18,6 +20,10 @@ export const useKnowledgeUploader = () => {
const { uploadFile } = useUploadApi();
const { t } = useTranslation(["upload"]);
const { crawlWebsiteUrl } = useCrawlApi();
const { createChat } = useChatApi();
const params = useParams();
const chatId = params?.chatId as UUID | undefined;
const { currentBrainId } = useBrainContext();
const addContent = (content: FeedItemType) => {
@ -28,7 +34,7 @@ export const useKnowledgeUploader = () => {
};
const crawlWebsiteHandler = useCallback(
async (url: string, brainId: UUID) => {
async (url: string, brainId: UUID, chat_id: UUID) => {
// Configure parameters
const config = {
url: url,
@ -42,6 +48,7 @@ export const useKnowledgeUploader = () => {
await crawlWebsiteUrl({
brainId,
config,
chat_id,
});
} catch (error: unknown) {
publish({
@ -56,13 +63,14 @@ export const useKnowledgeUploader = () => {
);
const uploadFileHandler = useCallback(
async (file: File, brainId: UUID) => {
async (file: File, brainId: UUID, chat_id: UUID) => {
const formData = new FormData();
formData.append("uploadFile", file);
try {
await uploadFile({
brainId: brainId,
brainId,
formData,
chat_id,
});
} catch (e: unknown) {
if (axios.isAxiosError(e) && e.response?.status === 403) {
@ -104,12 +112,14 @@ export const useKnowledgeUploader = () => {
return;
}
try {
const currentChatId = chatId ?? (await createChat("New Chat")).chat_id;
const uploadPromises = files.map((file) =>
uploadFileHandler(file, currentBrainId)
uploadFileHandler(file, currentBrainId, currentChatId)
);
const crawlPromises = urls.map((url) =>
crawlWebsiteHandler(url, currentBrainId)
crawlWebsiteHandler(url, currentBrainId, currentChatId)
);
await Promise.all([...uploadPromises, ...crawlPromises]);

View File

@ -4,6 +4,8 @@ import { useEffect } from "react";
import { useChatApi } from "@/lib/api/chat/useChatApi";
import { useChatContext } from "@/lib/context";
import { getMessagesFromChatHistory } from "../utils/getMessagesFromChatHistory";
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useSelectedChatPage = () => {
const { setHistory } = useChatContext();
@ -23,7 +25,7 @@ export const useSelectedChatPage = () => {
const chatHistory = await getHistory(chatId);
if (chatHistory.length > 0) {
setHistory(chatHistory);
setHistory(getMessagesFromChatHistory(chatHistory));
}
};
void fetchHistory();

View File

@ -18,6 +18,22 @@ export type ChatHistory = {
brain_name?: string;
};
type HistoryItemType = "MESSAGE" | "NOTIFICATION";
type Notification = {
id: string;
datetime: string;
chat_id?: string | null;
message?: string | null;
action: string;
status: string;
};
export type ChatItem = {
item_type: HistoryItemType;
body: ChatHistory | Notification;
};
export type ChatEntity = {
chat_id: UUID;
user_id: string;

View File

@ -0,0 +1,11 @@
import { ChatHistory, ChatItem } from "../types";
export const getMessagesFromChatHistory = (
chatHistory: ChatItem[]
): ChatHistory[] => {
const messages = chatHistory
.filter((item) => item.item_type === "MESSAGE")
.map((item) => item.body as ChatHistory);
return messages;
};

View File

@ -3,6 +3,7 @@ import { AxiosInstance } from "axios";
import {
ChatEntity,
ChatHistory,
ChatItem,
ChatQuestion,
} from "@/app/chat/[chatId]/types";
@ -55,8 +56,8 @@ export const addQuestion = async (
export const getHistory = async (
chatId: string,
axiosInstance: AxiosInstance
): Promise<ChatHistory[]> =>
(await axiosInstance.get<ChatHistory[]>(`/chat/${chatId}/history`)).data;
): Promise<ChatItem[]> =>
(await axiosInstance.get<ChatItem[]>(`/chat/${chatId}/history`)).data;
export type ChatUpdatableProperties = {
chat_name?: string;

View File

@ -24,6 +24,7 @@ describe("useCrawlApi", () => {
} = renderHook(() => useCrawlApi());
const crawlInputProps: CrawlInputProps = {
brainId: "e7001ccd-6d90-4eab-8c50-2f23d39441e4",
chat_id: "e7001ccd-6d90-4eab-8c50-2f23d39441es",
config: {
url: "https://en.wikipedia.org/wiki/Mali",
js: false,
@ -36,7 +37,9 @@ describe("useCrawlApi", () => {
expect(axiosPostMock).toHaveBeenCalledTimes(1);
expect(axiosPostMock).toHaveBeenCalledWith(
`/crawl?brain_id=${crawlInputProps.brainId}`,
`/crawl?brain_id=${crawlInputProps.brainId}&chat_id=${
crawlInputProps.chat_id ?? ""
}`,
crawlInputProps.config
);
});

View File

@ -5,6 +5,7 @@ import { ToastData } from "@/lib/components/ui/Toast/domain/types";
export type CrawlInputProps = {
brainId: UUID;
chat_id?: UUID;
config: {
url: string;
js: boolean;
@ -22,4 +23,7 @@ export const crawlWebsiteUrl = async (
props: CrawlInputProps,
axiosInstance: AxiosInstance
): Promise<CrawlResponse> =>
axiosInstance.post(`/crawl?brain_id=${props.brainId}`, props.config);
axiosInstance.post(
`/crawl?brain_id=${props.brainId}&chat_id=${props.chat_id ?? ""}`,
props.config
);

View File

@ -10,10 +10,14 @@ export type UploadResponse = {
export type UploadInputProps = {
brainId: UUID;
formData: FormData;
chat_id?: UUID;
};
export const uploadFile = async (
props: UploadInputProps,
axiosInstance: AxiosInstance
): Promise<UploadResponse> =>
axiosInstance.post(`/upload?brain_id=${props.brainId}`, props.formData);
axiosInstance.post(
`/upload?brain_id=${props.brainId}&chat_id=${props.chat_id ?? ""}`,
props.formData
);