Mirror of https://github.com/StanGirard/quivr.git
Feat: backend refactor (#306)
* fix: edge cases on migration scripts
* chore: remove unused deps.
* refactor: user_routes
* refactor: chat_routes
* refactor: upload_routes
* refactor: explore_routes
* refactor: crawl_routes
* chore(refactor): get current user
* refactor: more dead dependencies
* bug: wrap email in credentials dict.

---------

Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
parent 67530c13f2
commit ec29f30f32
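The thread running through all of these route refactors is the new `get_current_user` dependency: instead of each endpoint taking a raw `credentials: dict = Depends(JWTBearer())` and rebuilding a `User` by hand, routes now declare `current_user: User = Depends(get_current_user)`. A minimal, self-contained sketch of the pattern follows; `fake_jwt_bearer` and the example email are stand-ins, not repo code — the repo's `JWTBearer` validates the Authorization header and returns the decoded token payload.

```python
from fastapi import Depends, FastAPI
from pydantic import BaseModel


class User(BaseModel):
    email: str


def fake_jwt_bearer() -> dict:
    # Stand-in for JWTBearer().__call__: the real class verifies the bearer
    # token and returns the decoded JWT payload as a dict.
    return {"email": "user@example.com"}


def get_current_user(credentials: dict = Depends(fake_jwt_bearer)) -> User:
    # Same fallback as the diff below: a payload without an email becomes 'none'.
    return User(email=credentials.get("email", "none"))


app = FastAPI()


@app.get("/user")
async def get_user_endpoint(current_user: User = Depends(get_current_user)):
    # Endpoints receive a typed User instead of a raw credentials dict.
    return {"email": current_user.email}
```

Centralizing the `User` construction in one dependency is what lets every route file below drop its own `user = User(email=credentials.get('email', 'none'))` line.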
auth/auth_bearer.py

@@ -1,9 +1,9 @@
 import os
 from typing import Optional
 
-from fastapi import HTTPException, Request
+from fastapi import HTTPException, Request, Depends
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 
+from models.users import User
 from .auth_handler import decode_access_token
 
-
@@ -27,4 +27,8 @@ class JWTBearer(HTTPBearer):
 
     def verify_jwt(self, jwtoken: str):
         payload = decode_access_token(jwtoken)
         return payload
+
+
+def get_current_user(credentials: dict = Depends(JWTBearer())) -> User:
+    return User(email=credentials.get('email', 'none'))

main.py

@@ -1,36 +1,28 @@
 import os
 
 import pypandoc
-from auth.auth_bearer import JWTBearer
 from fastapi import FastAPI
 from logger import get_logger
 from middlewares.cors import add_cors_middleware
-from models.chats import ChatMessage
-from models.users import User
 from routes.chat_routes import chat_router
 from routes.crawl_routes import crawl_router
 from routes.explore_routes import explore_router
 from routes.misc_routes import misc_router
 from routes.upload_routes import upload_router
 from routes.user_routes import user_router
-from utils.vectors import (CommonsDep, create_user, similarity_search,
-                           update_user_request_count)
 
 logger = get_logger(__name__)
 
 app = FastAPI()
 
 
 add_cors_middleware(app)
-max_brain_size = os.getenv("MAX_BRAIN_SIZE")
-max_brain_size_with_own_key = os.getenv("MAX_BRAIN_SIZE_WITH_KEY",209715200)
-
 
 @app.on_event("startup")
 async def startup_event():
     pypandoc.download_pandoc()
 
 
 app.include_router(chat_router)
 app.include_router(crawl_router)
 app.include_router(explore_router)

routes/chat_routes.py

@@ -1,9 +1,8 @@
 import os
 import time
-from typing import Optional
 from uuid import UUID
 
-from auth.auth_bearer import JWTBearer
+from auth.auth_bearer import JWTBearer, get_current_user
 from fastapi import APIRouter, Depends, Request
 from models.chats import ChatMessage
 from models.users import User
@@ -14,107 +13,89 @@ from utils.vectors import (CommonsDep, create_chat, create_user,
 
 chat_router = APIRouter()
 
+def get_user_chats(commons, user_id):
+    response = commons['supabase'].from_('chats').select('chatId:chat_id, chatName:chat_name').filter("user_id", "eq", user_id).execute()
+    return response.data
+
+def get_chat_details(commons, chat_id):
+    response = commons['supabase'].from_('chats').select('*').filter("chat_id", "eq", chat_id).execute()
+    return response.data
+
+def delete_chat_from_db(commons, chat_id):
+    commons['supabase'].table("chats").delete().match({"chat_id": chat_id}).execute()
+
+def fetch_user_stats(commons, user, date):
+    response = commons['supabase'].from_('users').select('*').filter("email", "eq", user.email).filter("date", "eq", date).execute()
+    userItem = next(iter(response.data or []), {"requests_count": 0})
+    return userItem
+
 # get all chats
 @chat_router.get("/chat", dependencies=[Depends(JWTBearer())])
-async def get_chats(commons: CommonsDep, credentials: dict = Depends(JWTBearer())):
+async def get_chats(commons: CommonsDep, current_user: User = Depends(get_current_user)):
     date = time.strftime("%Y%m%d")
-    user_id = fetch_user_id_from_credentials(commons,date, credentials)
-
-    # Fetch all chats for the user
-    response = commons['supabase'].from_('chats').select('chatId:chat_id, chatName:chat_name').filter("user_id", "eq", user_id).execute()
-    chats = response.data
     # TODO: Only get the chat name instead of the history
+    user_id = fetch_user_id_from_credentials(commons, date, {"email": current_user.email})
+    chats = get_user_chats(commons, user_id)
     return {"chats": chats}
 
 # get one chat
 @chat_router.get("/chat/{chat_id}", dependencies=[Depends(JWTBearer())])
 async def get_chats(commons: CommonsDep, chat_id: UUID):
-
-    # Fetch all chats for the user
-    response = commons['supabase'].from_('chats').select('*').filter("chat_id", "eq", chat_id).execute()
-    chats = response.data
-
-    print("/chat/{chat_id}",chats)
-    return {"chatId": chat_id, "history": chats[0]['history']}
+    chats = get_chat_details(commons, chat_id)
+    if len(chats) > 0:
+        return {"chatId": chat_id, "history": chats[0]['history']}
+    else:
+        return {"error": "Chat not found"}
 
 # delete one chat
 @chat_router.delete("/chat/{chat_id}", dependencies=[Depends(JWTBearer())])
-async def delete_chat(commons: CommonsDep,chat_id: UUID):
-    commons['supabase'].table("chats").delete().match(
-        {"chat_id": chat_id}).execute()
-
+async def delete_chat(commons: CommonsDep, chat_id: UUID):
+    delete_chat_from_db(commons, chat_id)
     return {"message": f"{chat_id} has been deleted."}
 
+# helper method for update and create chat
+def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=False):
+    date = time.strftime("%Y%m%d")
+    user_id = fetch_user_id_from_credentials(commons, date, {"email": email})
+    max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
+    user_openai_api_key = request.headers.get('Openai-Api-Key')
+
+    userItem = fetch_user_stats(commons, User(email=email), date)
+    old_request_count = userItem['requests_count']
+
+    history = chat_message.history
+    history.append(("user", chat_message.question))
+
+    if old_request_count == 0:
+        create_user(email= email, date=date)
+    else:
+        update_user_request_count(email=email, date=date, requests_count=old_request_count + 1)
+
+    if user_openai_api_key is None and old_request_count >= float(max_requests_number):
+        history.append(('assistant', "You have reached your requests limit"))
+        update_chat(chat_id=chat_id, history=history)
+        return {"history": history}
+
+    answer = get_answer(commons, chat_message, email, user_openai_api_key)
+    history.append(("assistant", answer))
+
+    if is_new_chat:
+        chat_name = get_chat_name_from_first_question(chat_message)
+        new_chat = create_chat(user_id, history, chat_name)
+        chat_id = new_chat.data[0]['chat_id']
+    else:
+        update_chat(chat_id=chat_id, history=history)
+
+    return {"history": history, "chatId": chat_id}
+
 # update existing chat
 @chat_router.put("/chat/{chat_id}", dependencies=[Depends(JWTBearer())])
-async def chat_endpoint(request: Request,commons: CommonsDep, chat_id: UUID, chat_message: ChatMessage, credentials: dict = Depends(JWTBearer())):
-    user = User(email=credentials.get('email', 'none'))
-    date = time.strftime("%Y%m%d")
-    max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
-    user_openai_api_key = request.headers.get('Openai-Api-Key')
-
-    response = commons['supabase'].from_('users').select(
-        '*').filter("email", "eq", user.email).filter("date", "eq", date).execute()
-
-    userItem = next(iter(response.data or []), {"requests_count": 0})
-    old_request_count = userItem['requests_count']
-
-    history = chat_message.history
-    history.append(("user", chat_message.question))
-
-    if old_request_count == 0:
-        create_user(email= user.email, date=date)
-    elif old_request_count < float(max_requests_number) :
-        update_user_request_count(email=user.email, date=date, requests_count= old_request_count+1)
-    else:
-        history.append(('assistant', "You have reached your requests limit"))
-        update_chat(chat_id=chat_id, history=history)
-        return {"history": history }
-
-    answer = get_answer(commons, chat_message, user.email,user_openai_api_key)
-    history.append(("assistant", answer))
-    update_chat(chat_id=chat_id, history=history)
-
-    return {"history": history, "chatId": chat_id}
+async def chat_endpoint(request: Request, commons: CommonsDep, chat_id: UUID, chat_message: ChatMessage, current_user: User = Depends(get_current_user)):
+    return chat_handler(request, commons, chat_id, chat_message, current_user.email)
 
 # create new chat
 @chat_router.post("/chat", dependencies=[Depends(JWTBearer())])
-async def chat_endpoint(request: Request,commons: CommonsDep, chat_message: ChatMessage, credentials: dict = Depends(JWTBearer())):
-    user = User(email=credentials.get('email', 'none'))
-    date = time.strftime("%Y%m%d")
-
-    user_id = fetch_user_id_from_credentials(commons, date,credentials)
-
-    max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
-    user_openai_api_key = request.headers.get('Openai-Api-Key')
-
-    response = commons['supabase'].from_('users').select(
-        '*').filter("email", "eq", user.email).filter("date", "eq", date).execute()
-
-    userItem = next(iter(response.data or []), {"requests_count": 0})
-    old_request_count = userItem['requests_count']
-
-    history = chat_message.history
-    history.append(("user", chat_message.question))
-
-    chat_name = get_chat_name_from_first_question(chat_message)
-    print('chat_name',chat_name)
-    if user_openai_api_key is None:
-        if old_request_count == 0:
-            create_user(email= user.email, date=date)
-        elif old_request_count < float(max_requests_number) :
-            update_user_request_count(email=user.email, date=date, requests_count= old_request_count+1)
-        else:
-            history.append(('assistant', "You have reached your requests limit"))
-            new_chat = create_chat(user_id, history, chat_name)
-            return {"history": history, "chatId": new_chat.data[0]['chat_id'] }
-
-    answer = get_answer(commons, chat_message, user.email, user_openai_api_key)
-    history.append(("assistant", answer))
-    new_chat = create_chat(user_id, history, chat_name)
-
-    return {"history": history, "chatId": new_chat.data[0]['chat_id'], "chatName":new_chat.data[0]['chat_name'] }
+async def chat_endpoint(request: Request, commons: CommonsDep, chat_message: ChatMessage, current_user: User = Depends(get_current_user)):
+    return chat_handler(request, commons, None, chat_message, current_user.email, is_new_chat=True)

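One subtlety in the hunk above: `fetch_user_id_from_credentials` (defined in utils/vectors.py, further down) still expects the old credentials mapping and calls `.get('email', ...)` on it, so callers that now hold a `User` must wrap the address back into a dict. This is the "bug: wrap email in credentials dict" item from the commit message. A stripped-down illustration, with the helper reduced to just the part that reads the email (the date string is a placeholder):

```python
def fetch_user_id_from_credentials(commons, date, credentials):
    # Expects a mapping; a bare email string has no .get() and would raise
    # AttributeError, which is exactly what the wrapper dict fixes.
    return credentials.get('email', 'none')


class CurrentUser:
    email = "user@example.com"


current_user = CurrentUser()

# Wrapping the email restores the shape the helper expects:
user_id = fetch_user_id_from_credentials(None, "20230601", {"email": current_user.email})
assert user_id == "user@example.com"
```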
routes/crawl_routes.py

@@ -2,10 +2,9 @@ import os
 import shutil
 from tempfile import SpooledTemporaryFile
 
-from auth.auth_bearer import JWTBearer
+from auth.auth_bearer import JWTBearer, get_current_user
 from crawl.crawler import CrawlWebsite
 from fastapi import APIRouter, Depends, Request, UploadFile
-from middlewares.cors import add_cors_middleware
 from models.users import User
 from parsers.github import process_github
 from utils.file import convert_bytes
@@ -14,13 +13,7 @@ from utils.vectors import CommonsDep
 
 crawl_router = APIRouter()
 
-@crawl_router.post("/crawl/", dependencies=[Depends(JWTBearer())])
-async def crawl_endpoint(request: Request,commons: CommonsDep, crawl_website: CrawlWebsite, enable_summarization: bool = False, credentials: dict = Depends(JWTBearer())):
-    max_brain_size = os.getenv("MAX_BRAIN_SIZE")
-    if request.headers.get('Openai-Api-Key'):
-        max_brain_size = os.getenv("MAX_BRAIN_SIZE_WITH_KEY",209715200)
-
-    user = User(email=credentials.get('email', 'none'))
+def get_unique_user_data(commons, user):
     user_vectors_response = commons['supabase'].table("vectors").select(
         "name:metadata->>file_name, size:metadata->>file_size", count="exact") \
         .filter("user_id", "eq", user.email)\
@@ -28,21 +21,26 @@ async def crawl_endpoint(request: Request,commons: CommonsDep, crawl_website: Cr
     documents = user_vectors_response.data  # Access the data from the response
     # Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary
     user_unique_vectors = [dict(t) for t in set(tuple(d.items()) for d in documents)]
+    return user_unique_vectors
+
+@crawl_router.post("/crawl/", dependencies=[Depends(JWTBearer())])
+async def crawl_endpoint(request: Request,commons: CommonsDep, crawl_website: CrawlWebsite, enable_summarization: bool = False, current_user: User = Depends(get_current_user)):
+    max_brain_size = os.getenv("MAX_BRAIN_SIZE")
+    if request.headers.get('Openai-Api-Key'):
+        max_brain_size = os.getenv("MAX_BRAIN_SIZE_WITH_KEY",209715200)
+
+    user_unique_vectors = get_unique_user_data(commons, current_user)
 
     current_brain_size = sum(float(doc['size']) for doc in user_unique_vectors)
 
     file_size = 1000000
 
     remaining_free_space = float(max_brain_size) - (current_brain_size)
 
     if remaining_free_space - file_size < 0:
         message = {"message": f"❌ User's brain will exceed maximum capacity with this upload. Maximum file allowed is : {convert_bytes(remaining_free_space)}", "type": "error"}
     else:
-        user = User(email=credentials.get('email', 'none'))
         if not crawl_website.checkGithub():
 
             file_path, file_name = crawl_website.process()
 
             # Create a SpooledTemporaryFile from the file_path
             spooled_file = SpooledTemporaryFile()
             with open(file_path, 'rb') as f:
@@ -50,7 +48,7 @@ async def crawl_endpoint(request: Request,commons: CommonsDep, crawl_website: Cr
 
             # Pass the SpooledTemporaryFile to UploadFile
             file = UploadFile(file=spooled_file, filename=file_name)
-            message = await filter_file(file, enable_summarization, commons['supabase'], user=user, openai_api_key=request.headers.get('Openai-Api-Key', None))
+            message = await filter_file(file, enable_summarization, commons['supabase'], user=current_user, openai_api_key=request.headers.get('Openai-Api-Key', None))
             return message
         else:
-            message = await process_github(crawl_website.url, "false", user=user, supabase=commons['supabase'], user_openai_api_key=request.headers.get('Openai-Api-Key', None))
+            message = await process_github(crawl_website.url, "false", user=current_user, supabase=commons['supabase'], user_openai_api_key=request.headers.get('Openai-Api-Key', None))

routes/explore_routes.py

@@ -1,21 +1,22 @@
-from auth.auth_bearer import JWTBearer
+from auth.auth_bearer import JWTBearer, get_current_user
 from fastapi import APIRouter, Depends
 from models.users import User
 from utils.vectors import CommonsDep
 
 explore_router = APIRouter()
 
-@explore_router.get("/explore", dependencies=[Depends(JWTBearer())])
-async def explore_endpoint(commons: CommonsDep,credentials: dict = Depends(JWTBearer()) ):
-    user = User(email=credentials.get('email', 'none'))
+def get_unique_user_data(commons, user):
     response = commons['supabase'].table("vectors").select(
         "name:metadata->>file_name, size:metadata->>file_size", count="exact").filter("user_id", "eq", user.email).execute()
     documents = response.data  # Access the data from the response
     # Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary
     unique_data = [dict(t) for t in set(tuple(d.items()) for d in documents)]
-    # Sort the list of documents by size in decreasing order
-    unique_data.sort(key=lambda x: int(x['size']), reverse=True)
+    return unique_data
+
+@explore_router.get("/explore", dependencies=[Depends(JWTBearer())])
+async def explore_endpoint(commons: CommonsDep, current_user: User = Depends(get_current_user)):
+    unique_data = get_unique_user_data(commons, current_user)
+    unique_data.sort(key=lambda x: int(x['size']), reverse=True)
     return {"documents": unique_data}
 
 
@@ -29,12 +30,9 @@ async def delete_endpoint(commons: CommonsDep, file_name: str, credentials: dict
         {"metadata->>file_name": file_name, "user_id": user.email}).execute()
     return {"message": f"{file_name} of user {user.email} has been deleted."}
 
-
 @explore_router.get("/explore/{file_name}", dependencies=[Depends(JWTBearer())])
-async def download_endpoint(commons: CommonsDep, file_name: str,credentials: dict = Depends(JWTBearer()) ):
-    user = User(email=credentials.get('email', 'none'))
+async def download_endpoint(commons: CommonsDep, file_name: str, current_user: User = Depends(get_current_user)):
     response = commons['supabase'].table("vectors").select(
-        "metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url", "content").match({"metadata->>file_name": file_name, "user_id": user.email}).execute()
+        "metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url", "content").match({"metadata->>file_name": file_name, "user_id": current_user.email}).execute()
     documents = response.data
     # Returns all documents with the same file name
     return {"documents": documents}

routes/misc_routes.py

@@ -4,4 +4,4 @@ misc_router = APIRouter()
 
 @misc_router.get("/")
 async def root():
-    return {"status": "OK"}
\ No newline at end of file
+    return {"status": "OK"}

routes/upload_routes.py

@@ -1,10 +1,7 @@
 import os
 from tempfile import SpooledTemporaryFile
 
-from auth.auth_bearer import JWTBearer
-from crawl.crawler import CrawlWebsite
+from auth.auth_bearer import JWTBearer, get_current_user
 from fastapi import APIRouter, Depends, Request, UploadFile
-from models.chats import ChatMessage
 from models.users import User
 from utils.file import convert_bytes, get_file_size
 from utils.processors import filter_file
@@ -12,13 +9,7 @@ from utils.vectors import CommonsDep
 
 upload_router = APIRouter()
 
-@upload_router.post("/upload", dependencies=[Depends(JWTBearer())])
-async def upload_file(request: Request,commons: CommonsDep, file: UploadFile, enable_summarization: bool = False, credentials: dict = Depends(JWTBearer())):
-    max_brain_size = os.getenv("MAX_BRAIN_SIZE")
-    max_brain_size_with_own_key = os.getenv("MAX_BRAIN_SIZE_WITH_KEY",209715200)
-    remaining_free_space = 0
-
-    user = User(email=credentials.get('email', 'none'))
+def get_user_vectors(commons, user):
     user_vectors_response = commons['supabase'].table("vectors").select(
         "name:metadata->>file_name, size:metadata->>file_size", count="exact") \
         .filter("user_id", "eq", user.email)\
@@ -26,19 +17,27 @@ async def upload_file(request: Request,commons: CommonsDep, file: UploadFile, e
     documents = user_vectors_response.data  # Access the data from the response
     # Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary
     user_unique_vectors = [dict(t) for t in set(tuple(d.items()) for d in documents)]
+    return user_unique_vectors
+
+def calculate_remaining_space(request, max_brain_size, max_brain_size_with_own_key, current_brain_size):
+    remaining_free_space = float(max_brain_size_with_own_key) - current_brain_size if request.headers.get('Openai-Api-Key') else float(max_brain_size) - current_brain_size
+    return remaining_free_space
+
+@upload_router.post("/upload", dependencies=[Depends(JWTBearer())])
+async def upload_file(request: Request, commons: CommonsDep, file: UploadFile, enable_summarization: bool = False, current_user: User = Depends(get_current_user)):
+    max_brain_size = os.getenv("MAX_BRAIN_SIZE")
+    max_brain_size_with_own_key = os.getenv("MAX_BRAIN_SIZE_WITH_KEY",209715200)
+
+    user_unique_vectors = get_user_vectors(commons, current_user)
     current_brain_size = sum(float(doc['size']) for doc in user_unique_vectors)
 
-    if request.headers.get('Openai-Api-Key'):
-        remaining_free_space = float(max_brain_size_with_own_key) - (current_brain_size)
-    else:
-        remaining_free_space = float(max_brain_size) - (current_brain_size)
+    remaining_free_space = calculate_remaining_space(request, max_brain_size, max_brain_size_with_own_key, current_brain_size)
 
     file_size = get_file_size(file)
 
     if remaining_free_space - file_size < 0:
         message = {"message": f"❌ User's brain will exceed maximum capacity with this upload. Maximum file allowed is : {convert_bytes(remaining_free_space)}", "type": "error"}
     else:
-        message = await filter_file(file, enable_summarization, commons['supabase'], user, openai_api_key=request.headers.get('Openai-Api-Key', None))
+        message = await filter_file(file, enable_summarization, commons['supabase'], current_user, openai_api_key=request.headers.get('Openai-Api-Key', None))
 
     return message

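A side benefit of extracting `calculate_remaining_space` is that the quota rule becomes testable without a running server. A quick check under assumed byte values (`FakeRequest` is a test double, not repo code):

```python
class FakeRequest:
    def __init__(self, headers):
        self.headers = headers


def calculate_remaining_space(request, max_brain_size, max_brain_size_with_own_key, current_brain_size):
    # Verbatim logic from the diff: callers supplying their own OpenAI key
    # are granted the larger brain-size quota.
    remaining_free_space = float(max_brain_size_with_own_key) - current_brain_size if request.headers.get('Openai-Api-Key') else float(max_brain_size) - current_brain_size
    return remaining_free_space


assert calculate_remaining_space(FakeRequest({}), 1000, 209715200, 400) == 600.0
assert calculate_remaining_space(FakeRequest({'Openai-Api-Key': 'sk-...'}), 1000, 209715200, 400) == 209714800.0
```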
routes/user_routes.py

@@ -1,51 +1,53 @@
 
 import os
 import time
 
-from auth.auth_bearer import JWTBearer
+from auth.auth_bearer import JWTBearer, get_current_user
 from fastapi import APIRouter, Depends, Request
 from models.users import User
 from utils.vectors import CommonsDep
 
 user_router = APIRouter()
-max_brain_size_with_own_key = os.getenv("MAX_BRAIN_SIZE_WITH_KEY",209715200)
-@user_router.get("/user", dependencies=[Depends(JWTBearer())])
-async def get_user_endpoint(request: Request,commons: CommonsDep, credentials: dict = Depends(JWTBearer())):
-
-    # Create a function that returns the unique documents out of the vectors
-    # Create a function that returns the list of documents that can take in what to put in the select + the filter
-    user = User(email=credentials.get('email', 'none'))
-    # Cascade delete the summary from the database first, because it has a foreign key constraint
+
+MAX_BRAIN_SIZE_WITH_OWN_KEY = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200))
+
+def get_unique_documents(vectors):
+    # Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary
+    return [dict(t) for t in set(tuple(d.items()) for d in vectors)]
+
+def get_user_vectors(commons, email):
+    # Access the supabase table and get the vectors
     user_vectors_response = commons['supabase'].table("vectors").select(
         "name:metadata->>file_name, size:metadata->>file_size", count="exact") \
-        .filter("user_id", "eq", user.email)\
+        .filter("user_id", "eq", email)\
         .execute()
-    documents = user_vectors_response.data  # Access the data from the response
-    # Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary
-    user_unique_vectors = [dict(t) for t in set(tuple(d.items()) for d in documents)]
+    return user_vectors_response.data
 
-    current_brain_size = sum(float(doc['size']) for doc in user_unique_vectors)
+def get_user_request_stats(commons, email):
+    requests_stats = commons['supabase'].from_('users').select(
+        '*').filter("email", "eq", email).execute()
+    return requests_stats.data
 
-    max_brain_size = os.getenv("MAX_BRAIN_SIZE")
+@user_router.get("/user", dependencies=[Depends(JWTBearer())])
+async def get_user_endpoint(request: Request, commons: CommonsDep, current_user: User = Depends(get_current_user)):
+
+    user_vectors = get_user_vectors(commons, current_user.email)
+    user_unique_vectors = get_unique_documents(user_vectors)
+
+    current_brain_size = sum(float(doc.get('size', 0)) for doc in user_unique_vectors)
+
+    max_brain_size = int(os.getenv("MAX_BRAIN_SIZE", 0))
     if request.headers.get('Openai-Api-Key'):
-        max_brain_size = max_brain_size_with_own_key
+        max_brain_size = MAX_BRAIN_SIZE_WITH_OWN_KEY
 
-    # Create function get user request stats -> nombre de requetes par jour + max number of requests -> svg to display the number of requests ? une fusee ?
-    user = User(email=credentials.get('email', 'none'))
     date = time.strftime("%Y%m%d")
     max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
 
+    requests_stats = get_user_request_stats(commons, current_user.email)
 
-    if request.headers.get('Openai-Api-Key'):
-        max_brain_size = max_brain_size_with_own_key
-
-    requests_stats = commons['supabase'].from_('users').select(
-        '*').filter("email", "eq", user.email).execute()
-
-    return {"email":user.email,
+    return {"email": current_user.email,
         "max_brain_size": max_brain_size,
         "current_brain_size": current_brain_size,
         "max_requests_number": max_requests_number,
-        "requests_stats" : requests_stats.data,
+        "requests_stats" : requests_stats,
         "date": date,
     }

utils/vectors.py

@@ -1,8 +1,7 @@
 import os
-from typing import Annotated, List, Tuple
+from typing import Annotated
 
-from auth.auth_bearer import JWTBearer
-from fastapi import Depends, UploadFile
+from fastapi import Depends
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.schema import Document
 from langchain.vectorstores import SupabaseVectorStore
@@ -11,17 +10,16 @@ from llm.summarization import llm_evaluate_summaries, llm_summerize
 from logger import get_logger
 from models.chats import ChatMessage
 from models.users import User
 from pydantic import BaseModel
-
 from supabase import Client, create_client
 
 logger = get_logger(__name__)
 
 openai_api_key = os.environ.get("OPENAI_API_KEY")
 anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
 supabase_url = os.environ.get("SUPABASE_URL")
 supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
 
 embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
 supabase_client: Client = create_client(supabase_url, supabase_key)
 documents_vector_store = SupabaseVectorStore(
@@ -30,9 +28,6 @@ summaries_vector_store = SupabaseVectorStore(
     supabase_client, embeddings, table_name="summaries")
 
 
-
-
-
 def common_dependencies():
     return {
         "supabase": supabase_client,
@@ -45,8 +40,6 @@ def common_dependencies():
 CommonsDep = Annotated[dict, Depends(common_dependencies)]
 
 
-
-
 def create_summary(document_id, content, metadata):
     logger.info(f"Summarizing document {content[:100]}")
     summary = llm_summerize(content)
@@ -64,9 +57,8 @@ def create_vector(user_id,doc, user_openai_api_key=None):
     logger.info(f"Creating vector for document")
     logger.info(f"Document: {doc}")
     if user_openai_api_key:
-        documents_vector_store._embedding = embeddings_request = OpenAIEmbeddings(openai_api_key=user_openai_api_key)
+        documents_vector_store._embedding = OpenAIEmbeddings(openai_api_key=user_openai_api_key)
     try:
-
         sids = documents_vector_store.add_documents(
             [doc])
         if sids and len(sids) > 0:
@@ -110,8 +102,6 @@ def update_chat(chat_id, history):
 def create_embedding(content):
     return embeddings.embed_query(content)
 
-
-
 def similarity_search(query, table='match_summaries', top_k=5, threshold=0.5):
     query_embedding = create_embedding(query)
     summaries = supabase_client.rpc(
@@ -120,10 +110,6 @@ def similarity_search(query, table='match_summaries', top_k=5, threshold=0.5):
     ).execute()
     return summaries.data
 
-
-
-
-
 def fetch_user_id_from_credentials(commons: CommonsDep,date,credentials):
     user = User(email=credentials.get('email', 'none'))
 
@@ -139,22 +125,19 @@ def fetch_user_id_from_credentials(commons: CommonsDep,date,credentials):
     else:
         user_id = userItem['user_id']
 
-    # if not(user_id):
-        # throw error
     return user_id
 
 def get_chat_name_from_first_question(chat_message: ChatMessage):
     # Step 1: Get the summary of the first question
-    # first_question_summary = summerize_as_title(chat_message.question)
+    # first_question_summary = summarize_as_title(chat_message.question)
     # Step 2: Process this summary to create a chat name by selecting the first three words
     chat_name = ' '.join(chat_message.question.split()[:3])
-    print('chat_name')
 
     return chat_name
 
 def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user_openai_api_key:str):
     qa = get_qa_llm(chat_message, email, user_openai_api_key)
 
     if chat_message.use_summarization:
         # 1. get summaries from the vector store based on question
         summaries = similarity_search(
@@ -163,13 +146,12 @@ def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user
     evaluations = llm_evaluate_summaries(
         chat_message.question, summaries, chat_message.model)
     # 3. pull in the top documents from summaries
-    # logger.info('Evaluations: %s', evaluations)
     if evaluations:
-        reponse = commons['supabase'].from_('vectors').select(
+        response = commons['supabase'].from_('vectors').select(
             '*').in_('id', values=[e['document_id'] for e in evaluations]).execute()
     # 4. use top docs as additional context
     additional_context = '---\nAdditional Context={}'.format(
-        '---\n'.join(data['content'] for data in reponse.data)
+        '---\n'.join(data['content'] for data in response.data)
     ) + '\n'
     model_response = qa(
         {"question": additional_context + chat_message.question})
@@ -192,5 +174,3 @@ def get_answer(commons: CommonsDep, chat_message: ChatMessage, email: str, user
     answer = answer + "\n\nRef: " + "; ".join(files)
 
     return answer
-
-

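The deduplication idiom that recurs throughout these files, `[dict(t) for t in set(tuple(d.items()) for d in documents)]`, works because tuples of (key, value) pairs are hashable while dicts are not; note that the result order is arbitrary. A small demonstration with made-up rows:

```python
documents = [
    {"name": "a.pdf", "size": "100"},
    {"name": "a.pdf", "size": "100"},  # duplicate row, as returned by the vectors table
    {"name": "b.pdf", "size": "200"},
]

unique = [dict(t) for t in set(tuple(d.items()) for d in documents)]
assert len(unique) == 2  # duplicates collapse; ordering is not preserved
```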
migration script (SQL)

@@ -57,7 +57,7 @@ END;
 $$;
 
 -- Create stats table
-CREATE TABLE stats (
+CREATE TABLE IF NOT EXISTS stats (
     time TIMESTAMP,
     chat BOOLEAN,
     embedding BOOLEAN,
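`CREATE TABLE IF NOT EXISTS` turns the statement into a no-op when the table already exists, which is presumably the "edge cases on migration scripts" fix: the script can be re-run against an existing database without aborting. The semantics are easy to confirm (sqlite3 is used here purely for illustration; the project itself targets Supabase/Postgres):

```python
import sqlite3

con = sqlite3.connect(":memory:")
ddl = "CREATE TABLE IF NOT EXISTS stats (time TIMESTAMP, chat BOOLEAN, embedding BOOLEAN)"

con.execute(ddl)
con.execute(ddl)  # second run is a no-op instead of an error

# Without IF NOT EXISTS, re-running the migration fails:
try:
    con.execute("CREATE TABLE stats (time TIMESTAMP)")
except sqlite3.OperationalError as exc:
    print(exc)  # "table stats already exists"
```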