feat: 🎸 sources (#2092)

added metadata object a bit bigger

# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
Stan Girard 2024-01-25 18:56:54 -08:00 committed by GitHub
parent 7c9fdd28ed
commit 6c5496f797
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 41 additions and 29 deletions

View File

@ -14,7 +14,7 @@ from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger from logger import get_logger
from models import BrainSettings from models import BrainSettings
from modules.brain.service.brain_service import BrainService from modules.brain.service.brain_service import BrainService
from modules.chat.dto.chats import ChatQuestion from modules.chat.dto.chats import ChatQuestion, Sources
from modules.chat.dto.inputs import CreateChatHistory from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService from modules.chat.service.chat_service import ChatService
@ -296,21 +296,31 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
logger.error("Error during streaming tokens: %s", e) logger.error("Error during streaming tokens: %s", e)
try: try:
result = await run result = await run
source_documents = result.get("source_documents", [])
# Deduplicate source documents
source_documents = list(
{doc.metadata["file_name"]: doc for doc in source_documents}.values()
)
sources_list: List[Sources] = []
source_documents = result.get("source_documents", [])
if source_documents: if source_documents:
# Formatting the source documents using Markdown without new lines for each source serialized_sources_list = []
sources_list = [ for doc in source_documents:
f"[{doc.metadata['file_name']}])" for doc in source_documents sources_list.append(
] Sources(
**{
"name": doc.metadata["url"]
if "url" in doc.metadata
else doc.metadata["file_name"],
"type": "url" if "url" in doc.metadata else "file",
"source_url": doc.metadata["url"]
if "url" in doc.metadata
else "",
}
)
)
# Create metadata if it doesn't exist # Create metadata if it doesn't exist
if not streamed_chat_history.metadata: if not streamed_chat_history.metadata:
streamed_chat_history.metadata = {} streamed_chat_history.metadata = {}
streamed_chat_history.metadata["sources"] = sources_list # Serialize the sources list
serialized_sources_list = [source.dict() for source in sources_list]
streamed_chat_history.metadata["sources"] = serialized_sources_list
yield f"data: {json.dumps(streamed_chat_history.dict())}" yield f"data: {json.dumps(streamed_chat_history.dict())}"
else: else:
logger.info( logger.info(

View File

@ -28,6 +28,18 @@ class ChatQuestion(BaseModel):
prompt_id: Optional[UUID] prompt_id: Optional[UUID]
class Sources(BaseModel):
name: str
source_url: str
type: str
class Config:
json_encoders = {
**BaseModel.Config.json_encoders,
UUID: lambda v: str(v),
}
class ChatItemType(Enum): class ChatItemType(Enum):
MESSAGE = "MESSAGE" MESSAGE = "MESSAGE"
NOTIFICATION = "NOTIFICATION" NOTIFICATION = "NOTIFICATION"

View File

@ -1,12 +1,14 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from fastapi import HTTPException from logger import get_logger
from modules.notification.dto.outputs import DeleteNotificationResponse from modules.notification.dto.outputs import DeleteNotificationResponse
from modules.notification.entity.notification import Notification from modules.notification.entity.notification import Notification
from modules.notification.repository.notifications_interface import ( from modules.notification.repository.notifications_interface import (
NotificationInterface, NotificationInterface,
) )
logger = get_logger(__name__)
class Notifications(NotificationInterface): class Notifications(NotificationInterface):
def __init__(self, supabase_client): def __init__(self, supabase_client):
@ -35,7 +37,8 @@ class Notifications(NotificationInterface):
).data ).data
if response == []: if response == []:
raise HTTPException(404, "Notification not found") logger.info(f"Notification with id {notification_id} not found")
return None
return Notification(**response[0]) return Notification(**response[0])
@ -57,7 +60,8 @@ class Notifications(NotificationInterface):
) )
if response == []: if response == []:
raise HTTPException(404, "Notification not found") logger.info(f"Notification with id {notification_id} not found")
return None
return DeleteNotificationResponse( return DeleteNotificationResponse(
status="deleted", notification_id=notification_id status="deleted", notification_id=notification_id

View File

@ -8,8 +8,6 @@ from middlewares.auth import AuthBearer, get_current_user
from models import UserUsage from models import UserUsage
from modules.knowledge.dto.inputs import CreateKnowledgeProperties from modules.knowledge.dto.inputs import CreateKnowledgeProperties
from modules.knowledge.service.knowledge_service import KnowledgeService from modules.knowledge.service.knowledge_service import KnowledgeService
from modules.notification.dto.inputs import CreateNotificationProperties
from modules.notification.entity.notification import NotificationsStatusEnum
from modules.notification.service.notification_service import NotificationService from modules.notification.service.notification_service import NotificationService
from modules.user.entity.user_identity import UserIdentity from modules.user.entity.user_identity import UserIdentity
from packages.files.crawl.crawler import CrawlWebsite from packages.files.crawl.crawler import CrawlWebsite
@ -56,16 +54,6 @@ async def crawl_endpoint(
"type": "error", "type": "error",
} }
else: else:
crawl_notification_id = None
if chat_id:
crawl_notification_id = notification_service.add_notification(
CreateNotificationProperties(
action="CRAWL",
chat_id=chat_id,
status=NotificationsStatusEnum.Pending,
).id
)
knowledge_to_add = CreateKnowledgeProperties( knowledge_to_add = CreateKnowledgeProperties(
brain_id=brain_id, brain_id=brain_id,
url=crawl_website.url, url=crawl_website.url,
@ -78,7 +66,7 @@ async def crawl_endpoint(
process_crawl_and_notify.delay( process_crawl_and_notify.delay(
crawl_website_url=crawl_website.url, crawl_website_url=crawl_website.url,
brain_id=brain_id, brain_id=brain_id,
notification_id=crawl_notification_id, notification_id=None,
) )
return {"message": "Crawl processing has started."} return {"message": "Crawl processing has started."}

View File

@ -4,7 +4,6 @@ import { CopyButton } from "./components/CopyButton";
import { MessageContent } from "./components/MessageContent"; import { MessageContent } from "./components/MessageContent";
import { QuestionBrain } from "./components/QuestionBrain"; import { QuestionBrain } from "./components/QuestionBrain";
import { QuestionPrompt } from "./components/QuestionPrompt"; import { QuestionPrompt } from "./components/QuestionPrompt";
import { SourcesButton } from "./components/SourcesButton";
import { useMessageRow } from "./hooks/useMessageRow"; import { useMessageRow } from "./hooks/useMessageRow";
type MessageRowProps = { type MessageRowProps = {
@ -60,7 +59,6 @@ export const MessageRow = React.forwardRef(
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
{!isUserSpeaker && ( {!isUserSpeaker && (
<> <>
{hasSources && <SourcesButton sources={sourcesContent} />}
<CopyButton handleCopy={handleCopy} isCopied={isCopied} /> <CopyButton handleCopy={handleCopy} isCopied={isCopied} />
</> </>
)} )}