import os
import shutil
import time
from tempfile import SpooledTemporaryFile

import pypandoc
from auth_bearer import JWTBearer
from crawl.crawler import CrawlWebsite
from fastapi import Depends, FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from llm.qa import get_qa_llm
from llm.summarization import llm_evaluate_summaries
from logger import get_logger
from parsers.audio import process_audio
from parsers.common import file_already_exists
from parsers.csv import process_csv
from parsers.docx import process_docx
from parsers.epub import process_epub
from parsers.html import process_html
from parsers.markdown import process_markdown
from parsers.notebook import process_ipnyb
from parsers.odt import process_odt
from parsers.pdf import process_pdf
from parsers.powerpoint import process_powerpoint
from parsers.txt import process_txt
from pydantic import BaseModel
from supabase import Client
from utils import (ChatMessage, CommonsDep, convert_bytes, create_user,
                   get_file_size, similarity_search, update_user_request_count)

logger = get_logger(__name__)

app = FastAPI()

origins = [
    "http://localhost",
    "http://localhost:3000",
    "https://quivr.app",
    "https://www.quivr.app",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
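    """Download the pandoc binary once at startup so pypandoc-based document conversion is available to the file parsers."""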
    pypandoc.download_pandoc()
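

# Map of supported file extensions to the parser used to process uploaded files of that type.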
file_processors = {
    ".txt": process_txt,
    ".csv": process_csv,
    ".md": process_markdown,
    ".markdown": process_markdown,
    ".m4a": process_audio,
    ".mp3": process_audio,
    ".webm": process_audio,
    ".mp4": process_audio,
    ".mpga": process_audio,
    ".wav": process_audio,
    ".mpeg": process_audio,
    ".pdf": process_pdf,
    ".html": process_html,
    ".pptx": process_powerpoint,
    ".docx": process_docx,
    ".odt": process_odt,
    ".epub": process_epub,
    ".ipynb": process_ipnyb,
}


class User(BaseModel):
    email: str


async def filter_file(file: UploadFile, enable_summarization: bool, supabase_client: Client, user: User):
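    """Route an uploaded file to the parser registered for its extension.

    Returns a status message: a warning if the file already exists, an error if it
    is empty or its extension is unsupported, and a success message once the
    matching processor has run.
    """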
    if await file_already_exists(supabase_client, file, user):
        return {"message": f"🤔 {file.filename} already exists.", "type": "warning"}
    elif file.file._file.tell() < 1:
        return {"message": f"❌ {file.filename} is empty.", "type": "error"}
    else:
        file_extension = os.path.splitext(file.filename)[-1].lower()  # Convert file extension to lowercase
        if file_extension in file_processors:
            await file_processors[file_extension](file, enable_summarization, user)
            return {"message": f"✅ {file.filename} has been uploaded.", "type": "success"}
        else:
            return {"message": f"❌ {file.filename} is not supported.", "type": "error"}


@app.post("/upload", dependencies=[Depends(JWTBearer())])
async def upload_file(commons: CommonsDep, file: UploadFile, enable_summarization: bool = False, credentials: dict = Depends(JWTBearer())):
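    """Upload a file to the authenticated user's brain.

    The user's current storage is summed from their existing vectors; the upload is
    rejected when it would exceed the MAX_BRAIN_SIZE limit, otherwise the file is
    passed to filter_file for processing.
    """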
    max_brain_size = os.getenv("MAX_BRAIN_SIZE")
    user = User(email=credentials.get('email', 'none'))
    user_vectors_response = commons['supabase'].table("vectors").select(
        "name:metadata->>file_name, size:metadata->>file_size", count="exact") \
        .filter("user_id", "eq", user.email) \
        .execute()
    documents = user_vectors_response.data  # Access the data from the response
    # Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary
    user_unique_vectors = [dict(t) for t in set(tuple(d.items()) for d in documents)]

    current_brain_size = sum(float(doc['size']) for doc in user_unique_vectors)
    file_size = get_file_size(file)
    remaining_free_space = float(max_brain_size) - current_brain_size
    if remaining_free_space - file_size < 0:
        message = {"message": f"❌ User's brain will exceed maximum capacity with this upload. Maximum file allowed is: {convert_bytes(remaining_free_space)}", "type": "error"}
    else:
        message = await filter_file(file, enable_summarization, commons['supabase'], user)
    return message


@app.post("/chat/", dependencies=[Depends(JWTBearer())])
async def chat_endpoint(commons: CommonsDep, chat_message: ChatMessage, credentials: dict = Depends(JWTBearer())):
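    """Answer a chat question using the QA chain over the user's documents.

    Enforces the daily MAX_REQUESTS_NUMBER quota, optionally enriches the question
    with context retrieved via summaries when use_summarization is set, and returns
    the updated chat history.
    """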
    user = User(email=credentials.get('email', 'none'))
    date = time.strftime("%Y%m%d")
    max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
    response = commons['supabase'].from_('users').select(
        '*').filter("user_id", "eq", user.email).filter("date", "eq", date).execute()
    userItem = next(iter(response.data or []), {"requests_count": 0})
    old_request_count = userItem['requests_count']

    history = chat_message.history
    history.append(("user", chat_message.question))

    qa = get_qa_llm(chat_message, user.email)
    if old_request_count == 0:
        create_user(user_id=user.email, date=date)
    elif old_request_count < float(max_requests_number):
        update_user_request_count(user_id=user.email, date=date, requests_count=old_request_count + 1)
    else:
        history.append(('assistant', "You have reached your requests limit"))
        return {"history": history}

    if chat_message.use_summarization:
        # 1. get summaries from the vector store based on question
        summaries = similarity_search(
            chat_message.question, table='match_summaries')
        # 2. evaluate summaries against the question
        evaluations = llm_evaluate_summaries(
            chat_message.question, summaries, chat_message.model)
        # 3. pull in the top documents from summaries
        logger.info('Evaluations: %s', evaluations)
        additional_context = ''
        if evaluations:
            response = commons['supabase'].from_('vectors').select(
                '*').in_('id', values=[e['document_id'] for e in evaluations]).execute()
            # 4. use top docs as additional context
            additional_context = '---\nAdditional Context={}'.format(
                '---\n'.join(data['content'] for data in response.data)
            ) + '\n'
        model_response = qa(
            {"question": additional_context + chat_message.question})
    else:
        model_response = qa({"question": chat_message.question})

    history.append(("assistant", model_response["answer"]))
    return {"history": history}


@app.post("/crawl/", dependencies=[Depends(JWTBearer())])
async def crawl_endpoint(commons: CommonsDep, crawl_website: CrawlWebsite, enable_summarization: bool = False, credentials: dict = Depends(JWTBearer())):
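    """Crawl a website, wrap the result in an UploadFile and push it through the same filter_file pipeline as a direct upload."""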
    user = User(email=credentials.get('email', 'none'))
    file_path, file_name = crawl_website.process()
    # Create a SpooledTemporaryFile from the file_path
    spooled_file = SpooledTemporaryFile()
    with open(file_path, 'rb') as f:
        shutil.copyfileobj(f, spooled_file)
    # Pass the SpooledTemporaryFile to UploadFile
    file = UploadFile(file=spooled_file, filename=file_name)
    message = await filter_file(file, enable_summarization, commons['supabase'], user=user)
    return message


@app.get("/explore", dependencies=[Depends(JWTBearer())])
async def explore_endpoint(commons: CommonsDep, credentials: dict = Depends(JWTBearer())):
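    """Return the user's documents (name and size), deduplicated and sorted by size in decreasing order."""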
    user = User(email=credentials.get('email', 'none'))
    response = commons['supabase'].table("vectors").select(
        "name:metadata->>file_name, size:metadata->>file_size", count="exact").filter("user_id", "eq", user.email).execute()
    documents = response.data  # Access the data from the response
    # Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary
    unique_data = [dict(t) for t in set(tuple(d.items()) for d in documents)]
    # Sort the list of documents by size in decreasing order
    unique_data.sort(key=lambda x: int(x['size']), reverse=True)
    return {"documents": unique_data}


@app.delete("/explore/{file_name}", dependencies=[Depends(JWTBearer())])
async def delete_endpoint(commons: CommonsDep, file_name: str, credentials: dict = Depends(JWTBearer())):
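    """Delete a document from the user's brain: summaries are removed first because of the foreign key constraint, then the vectors."""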
    user = User(email=credentials.get('email', 'none'))
    # Cascade delete the summary from the database first, because it has a foreign key constraint
    commons['supabase'].table("summaries").delete().match(
        {"metadata->>file_name": file_name}).execute()
    commons['supabase'].table("vectors").delete().match(
        {"metadata->>file_name": file_name, "user_id": user.email}).execute()
    return {"message": f"{file_name} of user {user.email} has been deleted."}


@app.get("/explore/{file_name}", dependencies=[Depends(JWTBearer())])
async def download_endpoint(commons: CommonsDep, file_name: str, credentials: dict = Depends(JWTBearer())):
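    """Return the stored metadata (file name, size, extension and URL) for every vector matching the given file name and user."""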
    user = User(email=credentials.get('email', 'none'))
    response = commons['supabase'].table("vectors").select(
        "metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url").match(
        {"metadata->>file_name": file_name, "user_id": user.email}).execute()
    documents = response.data
    # Returns all documents with the same file name
    return {"documents": documents}


@app.get("/")
async def root():
    return {"message": "Hello World"}