quivr/backend/api.py

181 lines
6.6 KiB
Python

from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, UploadFile, File, HTTPException, Header, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import os
from pydantic import BaseModel
from typing import List, Tuple, Annotated
from supabase import create_client, Client
from tempfile import SpooledTemporaryFile
from auth_bearer import JWTBearer
import shutil
import pypandoc
from llm.summarization import llm_evaluate_summaries
from utils import similarity_search
from utils import CommonsDep
from utils import ChatMessage
from llm.qa import get_qa_llm
from parsers.common import file_already_exists
from parsers.txt import process_txt
from parsers.csv import process_csv
from parsers.docx import process_docx
from parsers.pdf import process_pdf
from parsers.notebook import process_ipnyb
from parsers.markdown import process_markdown
from parsers.powerpoint import process_powerpoint
from parsers.html import process_html
from parsers.epub import process_epub
from parsers.audio import process_audio
from crawl.crawler import CrawlWebsite
from logger import get_logger
# Module-level logger for this file.
logger = get_logger(__name__)

# The FastAPI application that the routes in this module are registered on.
app = FastAPI()

# Origins handed to the CORS middleware.
# NOTE(review): the "*" entry is combined with allow_credentials=True below;
# confirm that allowing every origin together with credentials is intended.
origins = ["http://localhost", "http://localhost:3000", "*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.on_event("startup")
async def startup_event():
    """Startup hook: download pandoc through pypandoc.

    NOTE(review): presumably required by the parsers that convert documents
    with pandoc -- confirm which ones depend on it.
    """
    pypandoc.download_pandoc()
# Maps a lower-cased file extension (leading dot included) to the awaitable
# handler that ingests files of that type. Key order matches the original
# literal.
file_processors = {
    ".txt": process_txt,
    ".csv": process_csv,
    # Both markdown suffixes share one parser.
    **dict.fromkeys((".md", ".markdown"), process_markdown),
    # Every supported audio/video container goes through the audio parser.
    **dict.fromkeys(
        (".m4a", ".mp3", ".webm", ".mp4", ".mpga", ".wav", ".mpeg"),
        process_audio,
    ),
    ".pdf": process_pdf,
    ".html": process_html,
    ".pptx": process_powerpoint,
    ".docx": process_docx,
    ".epub": process_epub,
    ".ipynb": process_ipnyb,
}
async def filter_file(file: UploadFile, enable_summarization: bool, supabase_client: Client):
    """Validate an upload and dispatch it to the processor for its extension.

    Checks run in this order: duplicate detection, emptiness, extension
    lookup in ``file_processors``.

    Args:
        file: The upload to ingest.
        enable_summarization: Forwarded unchanged to the selected processor.
        supabase_client: Client passed to ``file_already_exists``.

    Returns:
        A dict with a user-facing ``message`` and a ``type`` that is one of
        ``"warning"``, ``"error"`` or ``"success"``.
    """
    if await file_already_exists(supabase_client, file):
        return {"message": f"🤔 {file.filename} already exists.", "type": "warning"}

    # Use the public tell() of the underlying file object instead of reaching
    # into SpooledTemporaryFile's private `_file` attribute; for a spooled file
    # both report the same position, and other file objects no longer fail.
    # NOTE(review): tell() is the current stream position, not the size, so
    # this relies on the stream having been consumed beforehand (presumably by
    # file_already_exists) -- confirm.
    if file.file.tell() < 1:
        return {"message": f"{file.filename} is empty.", "type": "error"}

    # filename may be None; treat that as "no extension" rather than letting
    # os.path.splitext raise TypeError.
    file_extension = os.path.splitext(file.filename or "")[-1].lower()
    processor = file_processors.get(file_extension)
    if processor is None:
        return {"message": f"{file.filename} is not supported.", "type": "error"}

    await processor(file, enable_summarization)
    return {"message": f"{file.filename} has been uploaded.", "type": "success"}
@app.post("/upload", dependencies=[Depends(JWTBearer())])
async def upload_file(commons: CommonsDep, file: UploadFile, enable_summarization: bool = False):
    """Handle POST /upload by running the uploaded file through ``filter_file``.

    Returns whatever ``filter_file`` returns for the file.
    """
    supabase_client = commons['supabase']
    return await filter_file(file, enable_summarization, supabase_client)
@app.post("/chat/", dependencies=[Depends(JWTBearer())])
async def chat_endpoint(commons: CommonsDep, chat_message: ChatMessage):
    """Answer a chat question, optionally enriched with summarized documents.

    The user's question is appended to ``chat_message.history``, the QA chain
    built by ``get_qa_llm`` is called, and the answer is appended as the
    assistant turn.

    When ``chat_message.use_summarization`` is set, matching summaries are
    evaluated against the question and the content of the selected documents
    is prepended to the question as additional context. If no evaluation
    comes back, the plain question is used.

    Returns:
        ``{"history": history}`` with both new turns appended.
    """
    history = chat_message.history
    qa = get_qa_llm(chat_message)
    history.append(("user", chat_message.question))

    # Start from the plain question and only enrich it when context is found,
    # so the model is always queried and `model_response` is always bound.
    question = chat_message.question

    if chat_message.use_summarization:
        # 1. get summaries from the vector store based on question
        summaries = similarity_search(
            chat_message.question, table='match_summaries')
        # 2. evaluate summaries against the question
        evaluations = llm_evaluate_summaries(
            chat_message.question, summaries, chat_message.model)
        # 3. pull in the top documents from summaries
        logger.info('Evaluations: %s', evaluations)
        if evaluations:
            response = commons['supabase'].from_('documents').select(
                '*').in_('id', values=[e['document_id'] for e in evaluations]).execute()
            # 4. use top docs as additional context
            additional_context = '---\nAdditional Context={}'.format(
                '---\n'.join(data['content'] for data in response.data)
            ) + '\n'
            question = additional_context + question

    model_response = qa({"question": question})
    history.append(("assistant", model_response["answer"]))
    return {"history": history}
@app.post("/crawl/", dependencies=[Depends(JWTBearer())])
async def crawl_endpoint(commons: CommonsDep, crawl_website: CrawlWebsite, enable_summarization: bool = False):
    """Crawl a website and ingest the result like a regular upload.

    ``crawl_website.process()`` yields a path and a file name; the file at
    that path is copied into a temporary spooled file, wrapped in an
    ``UploadFile`` and passed to ``filter_file``.

    Returns:
        Whatever ``filter_file`` returns for the crawled file.
    """
    file_path, file_name = crawl_website.process()

    # Manage the spooled file with `with` so it is closed once ingestion has
    # finished instead of being left open.
    with SpooledTemporaryFile() as spooled_file:
        with open(file_path, 'rb') as f:
            shutil.copyfileobj(f, spooled_file)
        # NOTE(review): the stream position is left at the end of the copied
        # data, as before; filter_file's emptiness check reads that position.
        file = UploadFile(file=spooled_file, filename=file_name)
        return await filter_file(file, enable_summarization, commons['supabase'])
def _document_size(document):
    """Return the row's ``size`` as an int, or 0 when it is missing or non-numeric."""
    try:
        return int(document.get("size"))
    except (TypeError, ValueError):
        return 0


@app.get("/explore", dependencies=[Depends(JWTBearer())])
async def explore_endpoint(commons: CommonsDep):
    """List the distinct (name, size) pairs stored in the ``documents`` table.

    Returns:
        ``{"documents": [...]}`` where each entry has ``name`` and ``size``
        keys, de-duplicated and sorted by size in decreasing order. Rows
        whose size is missing or not numeric sort as size 0 instead of
        failing the request.
    """
    response = commons['supabase'].table("documents").select(
        "name:metadata->>file_name, size:metadata->>file_size", count="exact").execute()
    documents = response.data  # Access the data from the response

    # De-duplicate rows by their items. dict.fromkeys keeps first-seen order,
    # so rows with equal sizes come back in a stable order (a set would not
    # guarantee that).
    unique_items = dict.fromkeys(tuple(d.items()) for d in documents)
    unique_data = [dict(items) for items in unique_items]

    # Sort the list of documents by size in decreasing order
    unique_data.sort(key=_document_size, reverse=True)
    return {"documents": unique_data}
@app.delete("/explore/{file_name}", dependencies=[Depends(JWTBearer())])
async def delete_endpoint(commons: CommonsDep, file_name: str):
    """Delete every row whose metadata file name equals ``file_name``.

    Returns:
        A dict with a confirmation ``message``.
    """
    supabase = commons['supabase']
    # Order matters: summaries go first because of the foreign key constraint
    # between the two tables.
    for table_name in ("summaries", "documents"):
        supabase.table(table_name).delete().match(
            {"metadata->>file_name": file_name}).execute()
    return {"message": f"{file_name} has been deleted."}
@app.get("/explore/{file_name}", dependencies=[Depends(JWTBearer())])
async def download_endpoint(commons: CommonsDep, file_name: str):
    """Return the metadata rows of all documents named ``file_name``.

    Returns:
        ``{"documents": [...]}`` with one entry per matching row, each
        carrying the file name, size, extension and URL from its metadata.
    """
    columns = "metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url"
    query = commons['supabase'].table("documents").select(columns)
    response = query.match({"metadata->>file_name": file_name}).execute()
    # Every row sharing the file name is returned, not just the first one.
    return {"documents": response.data}
@app.get("/")
async def root():
    """Return a static greeting payload for GET /."""
    return {"message": "Hello World"}