refactor: delete common_dependencies function (#843)

* use function for get_documents_vector_store

* use function for get_embeddings

* use function for get_supabase_client

* use function for get_supabase_db

* delete lasts common_dependencies
This commit is contained in:
ChloeMouret 2023-08-03 20:24:42 +02:00 committed by GitHub
parent b3fb8fc3bc
commit 711e9fb8c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 303 additions and 347 deletions

View File

@ -1,7 +1,7 @@
from datetime import datetime
from fastapi import HTTPException
from models.settings import common_dependencies
from models.settings import get_supabase_db
from models.users import User
from pydantic import DateError
@ -12,8 +12,8 @@ async def verify_api_key(
try:
# Use UTC time to avoid timezone issues
current_date = datetime.utcnow().date()
commons = common_dependencies()
result = commons["db"].get_active_api_key(api_key)
supabase_db = get_supabase_db()
result = supabase_db.get_active_api_key(api_key)
if result.data is not None and len(result.data) > 0:
api_key_creation_date = datetime.strptime(
@ -33,10 +33,10 @@ async def verify_api_key(
async def get_user_from_api_key(
api_key: str,
) -> User:
commons = common_dependencies()
supabase_db = get_supabase_db()
# Lookup the user_id from the api_keys table
user_id_data = commons["db"].get_user_id_by_api_key(api_key)
user_id_data = supabase_db.get_user_id_by_api_key(api_key)
if not user_id_data.data:
raise HTTPException(status_code=400, detail="Invalid API key.")
@ -44,7 +44,7 @@ async def get_user_from_api_key(
user_id = user_id_data.data[0]["user_id"]
# Lookup the email from the users table. Todo: remove and use user_id for credentials
user_email_data = commons["db"].get_user_email(user_id)
user_email_data = supabase_db.get_user_email(user_id)
email = user_email_data.data[0]["email"] if user_email_data.data else None
return User(email=email, id=user_id)

View File

@ -2,11 +2,12 @@ from typing import Any, List, Optional
from uuid import UUID
from logger import get_logger
from models.databases.supabase.supabase import SupabaseDB
from models.settings import BrainRateLimiting, get_supabase_client, get_supabase_db
from pydantic import BaseModel
from supabase.client import Client
from utils.vectors import get_unique_files_from_vector_ids
from models.settings import BrainRateLimiting, CommonsDep, common_dependencies
logger = get_logger(__name__)
@ -27,8 +28,12 @@ class Brain(BaseModel):
arbitrary_types_allowed = True
@property
def commons(self) -> CommonsDep:
return common_dependencies()
def supabase_client(self) -> Client:
return get_supabase_client()
@property
def supabase_db(self) -> SupabaseDB:
return get_supabase_db()
@property
def brain_size(self):
@ -46,7 +51,7 @@ class Brain(BaseModel):
@classmethod
def create(cls, *args, **kwargs):
commons = common_dependencies()
commons = {"supabase": get_supabase_client()}
return cls(
commons=commons, *args, **kwargs # pyright: ignore reportPrivateUsage=none
) # pyright: ignore reportPrivateUsage=none
@ -54,8 +59,7 @@ class Brain(BaseModel):
# TODO: move this to a brand new BrainService
def get_brain_users(self):
response = (
self.commons["supabase"]
.table("brains_users")
self.supabase_client.table("brains_users")
.select("id:brain_id, *")
.filter("brain_id", "eq", self.id)
.execute()
@ -65,33 +69,32 @@ class Brain(BaseModel):
# TODO: move this to a brand new BrainService
def delete_user_from_brain(self, user_id):
results = (
self.commons["supabase"]
.table("brains_users")
self.supabase_client.table("brains_users")
.select("*")
.match({"brain_id": self.id, "user_id": user_id})
.execute()
)
if len(results.data) != 0:
self.commons["supabase"].table("brains_users").delete().match(
self.supabase_client.table("brains_users").delete().match(
{"brain_id": self.id, "user_id": user_id}
).execute()
def delete_brain(self, user_id):
results = self.commons["db"].delete_brain_user_by_id(user_id, self.id)
results = self.supabase_db.delete_brain_user_by_id(user_id, self.id)
if len(results.data) == 0:
return {"message": "You are not the owner of this brain."}
else:
self.commons["db"].delete_brain_vector(self.id)
self.commons["db"].delete_brain_user(self.id)
self.commons["db"].delete_brain(self.id)
self.supabase_db.delete_brain_vector(self.id)
self.supabase_db.delete_brain_user(self.id)
self.supabase_db.delete_brain(self.id)
def create_brain_vector(self, vector_id, file_sha1):
return self.commons["db"].create_brain_vector(self.id, vector_id, file_sha1)
return self.supabase_db.create_brain_vector(self.id, vector_id, file_sha1)
def get_vector_ids_from_file_sha1(self, file_sha1: str):
return self.commons["db"].get_vector_ids_from_file_sha1(file_sha1)
return self.supabase_db.get_vector_ids_from_file_sha1(file_sha1)
def update_brain_with_file(self, file_sha1: str):
# not used
@ -104,10 +107,10 @@ class Brain(BaseModel):
Retrieve unique brain data (i.e. uploaded files and crawled websites).
"""
vector_ids = self.commons["db"].get_brain_vector_ids(self.id)
vector_ids = self.supabase_db.get_brain_vector_ids(self.id)
self.files = get_unique_files_from_vector_ids(vector_ids)
return self.files
def delete_file_from_brain(self, file_name: str):
return self.commons["db"].delete_file_from_brain(self.id, file_name)
return self.supabase_db.delete_file_from_brain(self.id, file_name)

View File

@ -1,9 +1,9 @@
from uuid import UUID
from logger import get_logger
from models.settings import get_supabase_client
from pydantic import BaseModel
from models.settings import CommonsDep, common_dependencies
from supabase.client import Client
logger = get_logger(__name__)
@ -17,14 +17,13 @@ class BrainSubscription(BaseModel):
arbitrary_types_allowed = True
@property
def commons(self) -> CommonsDep:
return common_dependencies()
def supabase_client(self) -> Client:
return get_supabase_client()
def create_subscription_invitation(self):
logger.info("Creating subscription invitation")
response = (
self.commons["supabase"]
.table("brain_subscription_invitations")
self.supabase_client.table("brain_subscription_invitations")
.insert(
{
"brain_id": str(self.brain_id),
@ -39,8 +38,7 @@ class BrainSubscription(BaseModel):
def update_subscription_invitation(self):
logger.info("Updating subscription invitation")
response = (
self.commons["supabase"]
.table("brain_subscription_invitations")
self.supabase_client.table("brain_subscription_invitations")
.update({"rights": self.rights})
.eq("brain_id", str(self.brain_id))
.eq("email", self.email)
@ -50,8 +48,7 @@ class BrainSubscription(BaseModel):
def create_or_update_subscription_invitation(self):
response = (
self.commons["supabase"]
.table("brain_subscription_invitations")
self.supabase_client.table("brain_subscription_invitations")
.select("*")
.eq("brain_id", str(self.brain_id))
.eq("email", self.email)

View File

@ -7,7 +7,8 @@ from fastapi import UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from logger import get_logger
from models.brains import Brain
from models.settings import CommonsDep, common_dependencies
from models.databases.supabase.supabase import SupabaseDB
from models.settings import get_supabase_db
from pydantic import BaseModel
from utils.file import compute_sha1_from_file
@ -26,11 +27,11 @@ class File(BaseModel):
chunk_size: int = 500
chunk_overlap: int = 0
documents: Optional[Any] = None
_commons: Optional[CommonsDep] = None
@property
def commons(self) -> CommonsDep:
return common_dependencies()
@property
def supabase_db(self) -> SupabaseDB:
return get_supabase_db()
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -98,7 +99,9 @@ class File(BaseModel):
Set the vectors_ids property with the ids of the vectors
that are associated with the file in the vectors table
"""
self.vectors_ids = self.commons["db"].get_vectors_by_file_sha1(self.file_sha1).data
self.vectors_ids = self.supabase_db.get_vectors_by_file_sha1(
self.file_sha1
).data
def file_already_exists(self):
"""
@ -126,7 +129,9 @@ class File(BaseModel):
Args:
brain_id (str): Brain id
"""
response = self.commons["db"].get_brain_vectors_by_brain_id_and_file_sha1(brain_id, self.file_sha1)
response = self.supabase_db.get_brain_vectors_by_brain_id_and_file_sha1(
brain_id, self.file_sha1
)
print("response.data", response.data)
if len(response.data) == 0:

View File

@ -1,6 +1,3 @@
from typing import Annotated, TypedDict
from fastapi import Depends
from langchain.embeddings.openai import OpenAIEmbeddings
from models.databases.supabase.supabase import SupabaseDB
from pydantic import BaseSettings
@ -28,34 +25,34 @@ class LLMSettings(BaseSettings):
model_path: str = "./local_models/ggml-gpt4all-j-v1.3-groovy.bin"
class CommonDependencies(TypedDict):
supabase: Client
db: SupabaseDB
embeddings: OpenAIEmbeddings
documents_vector_store: SupabaseVectorStore
def get_supabase_client() -> Client:
settings = BrainSettings() # pyright: ignore reportPrivateUsage=none
supabase_client: Client = create_client(
settings.supabase_url, settings.supabase_service_key
)
return supabase_client
def common_dependencies() -> CommonDependencies:
def get_supabase_db() -> SupabaseDB:
supabase_client = get_supabase_client()
return SupabaseDB(supabase_client)
def get_embeddings() -> OpenAIEmbeddings:
settings = BrainSettings() # pyright: ignore reportPrivateUsage=none
embeddings = OpenAIEmbeddings(
openai_api_key=settings.openai_api_key
) # pyright: ignore reportPrivateUsage=none
return embeddings
def get_documents_vector_store() -> SupabaseVectorStore:
settings = BrainSettings() # pyright: ignore reportPrivateUsage=none
embeddings = get_embeddings()
supabase_client: Client = create_client(
settings.supabase_url, settings.supabase_service_key
)
documents_vector_store = SupabaseVectorStore(
supabase_client, embeddings, table_name="vectors"
)
db = None
db = SupabaseDB(supabase_client)
return {
"supabase": supabase_client,
"db": db,
"embeddings": embeddings,
"documents_vector_store": documents_vector_store,
}
CommonsDep = Annotated[dict, Depends(common_dependencies)]
return documents_vector_store

View File

@ -2,13 +2,10 @@ from typing import Optional
from uuid import UUID
from logger import get_logger
from models.settings import common_dependencies, CommonsDep
from models.databases.supabase.supabase import SupabaseDB
from models.settings import get_supabase_db
from pydantic import BaseModel
from models.settings import common_dependencies
logger = get_logger(__name__)
@ -18,10 +15,10 @@ class User(BaseModel):
email: Optional[str]
user_openai_api_key: Optional[str] = None
requests_count: int = 0
@property
def commons(self) -> CommonsDep:
return common_dependencies()
def supabase_db(self) -> SupabaseDB:
return get_supabase_db()
# [TODO] Rename the user table and its references to 'user_usage'
@ -34,13 +31,13 @@ class User(BaseModel):
"""
logger.info(f"New user entry in db document for user {self.email}")
return self.commons["db"].create_user(self.id, self.email, date)
return self.supabase_db.create_user(self.id, self.email, date)
def get_user_request_stats(self):
"""
Fetch the user request stats from the database
"""
request = self.commons["db"].get_user_request_stats(self.id)
request = self.supabase_db.get_user_request_stats(self.id)
return request.data
@ -48,11 +45,11 @@ class User(BaseModel):
"""
Increment the user request count in the database
"""
response = self.commons["db"].fetch_user_requests_count(self.id, date)
response = self.supabase_db.fetch_user_requests_count(self.id, date)
userItem = next(iter(response.data or []), {"requests_count": 0})
requests_count = userItem["requests_count"] + 1
logger.info(f"User {self.email} request count updated to {requests_count}")
self.commons["db"].update_user_request_count(self.id, requests_count, date)
self.supabase_db.update_user_request_count(self.id, requests_count, date)
self.requests_count = requests_count

View File

@ -6,12 +6,11 @@ import openai
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from models.files import File
from models.settings import CommonsDep
from models.settings import get_documents_vector_store
from utils.file import compute_sha1_from_content
async def process_audio(
commons: CommonsDep, # pyright: ignore reportPrivateUsage=none
file: File,
enable_summarization: bool,
user,
@ -21,6 +20,7 @@ async def process_audio(
file_sha = ""
dateshort = time.strftime("%Y%m%d-%H%M%S")
file_meta_name = f"audiotranscript_{dateshort}.txt"
documents_vector_store = get_documents_vector_store()
# use this for whisper
os.environ.get("OPENAI_API_KEY")
@ -78,9 +78,7 @@ async def process_audio(
for text in texts
]
commons.documents_vector_store.add_documents( # pyright: ignore reportPrivateUsage=none
docs_with_metadata
)
documents_vector_store.add_documents(docs_with_metadata)
finally:
if temp_filename and os.path.exists(temp_filename):

View File

@ -3,12 +3,10 @@ import time
from langchain.schema import Document
from models.brains import Brain
from models.files import File
from models.settings import CommonsDep
from utils.vectors import Neurons
async def process_file(
commons: CommonsDep,
file: File,
loader_class,
enable_summarization,
@ -31,7 +29,7 @@ async def process_file(
}
doc_with_metadata = Document(page_content=doc.page_content, metadata=metadata)
neurons = Neurons(commons=commons)
neurons = Neurons()
created_vector = neurons.create_vector(doc_with_metadata, user_openai_api_key)
# add_usage(stats_db, "embedding", "audio", metadata={"file_name": file_meta_name,"file_type": ".txt", "chunk_size": chunk_size, "chunk_overlap": chunk_overlap})

View File

@ -1,22 +1,19 @@
from langchain.document_loaders import CSVLoader
from models.files import File
from models.settings import CommonsDep
from .common import process_file
def process_csv(
commons: CommonsDep,
file: File,
enable_summarization,
brain_id,
user_openai_api_key,
):
return process_file(
commons,
file,
CSVLoader,
enable_summarization,
brain_id,
user_openai_api_key,
file=file,
loader_class=CSVLoader,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
)

View File

@ -1,9 +1,14 @@
from langchain.document_loaders import Docx2txtLoader
from models.files import File
from models.settings import CommonsDep
from .common import process_file
def process_docx(commons: CommonsDep, file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(commons, file, Docx2txtLoader, enable_summarization, brain_id, user_openai_api_key)
def process_docx(file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(
file=file,
loader_class=Docx2txtLoader,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
)

View File

@ -1,9 +1,14 @@
from langchain.document_loaders.epub import UnstructuredEPubLoader
from models.files import File
from models.settings import CommonsDep
from .common import process_file
def process_epub(commons: CommonsDep, file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(commons, file, UnstructuredEPubLoader, enable_summarization, brain_id, user_openai_api_key)
def process_epub(file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(
file=file,
loader_class=UnstructuredEPubLoader,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
)

View File

@ -6,13 +6,11 @@ from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from models.brains import Brain
from models.files import File
from models.settings import CommonsDep
from utils.file import compute_sha1_from_content
from utils.vectors import Neurons
async def process_github(
commons: CommonsDep, # pyright: ignore reportPrivateUsage=none
repo,
enable_summarization,
brain_id,
@ -70,7 +68,7 @@ async def process_github(
if not file_exists:
print(f"Creating entry for file {file.file_sha1} in vectors...")
neurons = Neurons(commons=commons)
neurons = Neurons()
created_vector = neurons.create_vector(
doc_with_metadata, user_openai_api_key
)

View File

@ -1,30 +1,14 @@
import re
import unicodedata
import requests
from langchain.document_loaders import UnstructuredHTMLLoader
from models.files import File
from models.settings import CommonsDep
from .common import process_file
def process_html(
commons: CommonsDep, file: File, enable_summarization, brain_id, user_openai_api_key
):
def process_html(file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(
commons,
file,
UnstructuredHTMLLoader,
enable_summarization,
brain_id,
user_openai_api_key,
file=file,
loader_class=UnstructuredHTMLLoader,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
)
def get_html(url):
response = requests.get(url)
if response.status_code == 200:
return response.text
else:
return None

View File

@ -1,9 +1,14 @@
from langchain.document_loaders import UnstructuredMarkdownLoader
from models.files import File
from models.settings import CommonsDep
from .common import process_file
def process_markdown(commons: CommonsDep, file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(commons, file, UnstructuredMarkdownLoader, enable_summarization, brain_id, user_openai_api_key)
def process_markdown(file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(
file=file,
loader_class=UnstructuredMarkdownLoader,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
)

View File

@ -1,9 +1,14 @@
from langchain.document_loaders import NotebookLoader
from models.files import File
from models.settings import CommonsDep
from .common import process_file
def process_ipnyb(commons: CommonsDep, file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(commons, file, NotebookLoader, enable_summarization, brain_id, user_openai_api_key)
def process_ipnyb(file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(
file=file,
loader_class=NotebookLoader,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
)

View File

@ -1,9 +1,14 @@
from langchain.document_loaders import PyMuPDFLoader
from models.files import File
from models.settings import CommonsDep
from .common import process_file
def process_odt(commons: CommonsDep, file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(commons, file, PyMuPDFLoader, enable_summarization, brain_id, user_openai_api_key)
def process_odt(file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(
file=file,
loader_class=PyMuPDFLoader,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
)

View File

@ -1,10 +1,14 @@
from langchain.document_loaders import PyMuPDFLoader
from models.files import File
from models.settings import CommonsDep
from .common import process_file
def process_pdf(commons: CommonsDep, file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(commons, file, PyMuPDFLoader, enable_summarization, brain_id, user_openai_api_key)
def process_pdf(file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(
file=file,
loader_class=PyMuPDFLoader,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
)

View File

@ -1,9 +1,14 @@
from langchain.document_loaders import UnstructuredPowerPointLoader
from models.files import File
from models.settings import CommonsDep
from .common import process_file
def process_powerpoint(commons: CommonsDep, file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(commons, file, UnstructuredPowerPointLoader, enable_summarization, brain_id, user_openai_api_key)
def process_powerpoint(file: File, enable_summarization, brain_id, user_openai_api_key):
return process_file(
file=file,
loader_class=UnstructuredPowerPointLoader,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
)

View File

@ -1,9 +1,14 @@
from langchain.document_loaders import TextLoader
from models.files import File
from models.settings import CommonsDep
from .common import process_file
async def process_txt(commons: CommonsDep, file: File, enable_summarization, brain_id, user_openai_api_key):
return await process_file(commons, file, TextLoader, enable_summarization, brain_id,user_openai_api_key)
async def process_txt(file: File, enable_summarization, brain_id, user_openai_api_key):
return await process_file(
file=file,
loader_class=TextLoader,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=user_openai_api_key,
)

View File

@ -1,9 +1,9 @@
from models.brain_entity import BrainEntity
from models.databases.supabase.brains import CreateBrainProperties
from models.settings import common_dependencies
from models.settings import get_supabase_db
def create_brain(brain: CreateBrainProperties) -> BrainEntity:
commons = common_dependencies()
supabase_db = get_supabase_db()
return commons["db"].create_brain(brain.dict(exclude_unset=True))
return supabase_db.create_brain(brain.dict(exclude_unset=True))

View File

@ -1,14 +1,14 @@
from uuid import UUID
from models.settings import common_dependencies
from models.settings import get_supabase_db
from routes.authorizations.types import RoleEnum
def create_brain_user(
user_id: UUID, brain_id: UUID, rights: RoleEnum, is_default_brain: bool
) -> None:
commons = common_dependencies()
commons["db"].create_brain_user(
supabase_db = get_supabase_db()
supabase_db.create_brain_user(
user_id=user_id,
brain_id=brain_id,
rights=rights,

View File

@ -1,10 +1,10 @@
from uuid import UUID
from models.brain_entity import BrainEntity
from models.settings import common_dependencies
from models.settings import get_supabase_db
def get_brain_by_id(brain_id: UUID) -> BrainEntity | None:
commons = common_dependencies()
supabase_db = get_supabase_db()
return commons["db"].get_brain_by_id(brain_id)
return supabase_db.get_brain_by_id(brain_id)

View File

@ -1,14 +1,13 @@
from uuid import UUID
from models.brain_entity import BrainEntity
from models.settings import common_dependencies
from models.settings import get_supabase_client
def get_brain_details(brain_id: UUID) -> BrainEntity | None:
commons = common_dependencies()
supabase_client = get_supabase_client()
response = (
commons["supabase"]
.from_("brains")
supabase_client.from_("brains")
.select("*")
.filter("brain_id", "eq", brain_id)
.execute()

View File

@ -1,9 +1,9 @@
from uuid import UUID
from models.brain_entity import MinimalBrainEntity
from models.settings import common_dependencies
from models.settings import get_supabase_db
def get_brain_for_user(user_id: UUID, brain_id: UUID) -> MinimalBrainEntity:
commons = common_dependencies()
return commons["db"].get_brain_for_user(user_id, brain_id)
supabase_db = get_supabase_db()
return supabase_db.get_brain_for_user(user_id, brain_id)

View File

@ -2,15 +2,15 @@ from uuid import UUID
from logger import get_logger
from models.brain_entity import BrainEntity
from models.settings import common_dependencies
from models.settings import get_supabase_db
from repository.brain.get_brain_by_id import get_brain_by_id
logger = get_logger(__name__)
def get_user_default_brain(user_id: UUID) -> BrainEntity | None:
commons = common_dependencies()
brain_id = commons["db"].get_default_user_brain_id(user_id)
supabase_db = get_supabase_db()
brain_id = supabase_db.get_default_user_brain_id(user_id)
logger.info("Default brain response:", brain_id)

View File

@ -1,11 +1,11 @@
from uuid import UUID
from models.brain_entity import BrainEntity
from models.settings import common_dependencies
from models.settings import get_supabase_db
def get_user_brains(user_id: UUID) -> list[BrainEntity]:
commons = common_dependencies()
results = commons["db"].get_user_brains(user_id)
supabase_db = get_supabase_db()
results = supabase_db.get_user_brains(user_id)
return results

View File

@ -1,19 +1,19 @@
from uuid import UUID
from models.settings import common_dependencies
from models.settings import get_supabase_client
from repository.brain.get_default_user_brain import get_user_default_brain
def set_as_default_brain_for_user(user_id: UUID, brain_id: UUID):
commons = common_dependencies()
supabase_client = get_supabase_client()
old_default_brain = get_user_default_brain(user_id)
if old_default_brain is not None:
commons["supabase"].table("brains_users").update(
{"default_brain": False}
).match({"brain_id": old_default_brain.brain_id, "user_id": user_id}).execute()
supabase_client.table("brains_users").update({"default_brain": False}).match(
{"brain_id": old_default_brain.brain_id, "user_id": user_id}
).execute()
commons["supabase"].table("brains_users").update({"default_brain": True}).match(
supabase_client.table("brains_users").update({"default_brain": True}).match(
{"brain_id": brain_id, "user_id": user_id}
).execute()

View File

@ -2,11 +2,11 @@ from uuid import UUID
from models.brain_entity import BrainEntity
from models.databases.supabase.brains import BrainUpdatableProperties
from models.settings import common_dependencies
from models.settings import get_supabase_db
def update_brain_by_id(brain_id: UUID, brain: BrainUpdatableProperties) -> BrainEntity:
"""Update a prompt by id"""
commons = common_dependencies()
supabase_db = get_supabase_db()
return commons["db"].update_brain_by_id(brain_id, brain)
return supabase_db.update_brain_by_id(brain_id, brain)

View File

@ -1,12 +1,12 @@
from uuid import UUID
from models.settings import common_dependencies
from models.settings import get_supabase_client
def update_brain_user_rights(brain_id: UUID, user_id: UUID, rights: str) -> None:
commons = common_dependencies()
supabase_client = get_supabase_client()
commons["supabase"].table("brains_users").update({"rights": rights}).eq(
supabase_client.table("brains_users").update({"rights": rights}).eq(
"brain_id",
brain_id,
).eq("user_id", user_id).execute()

View File

@ -1,21 +1,18 @@
from typing import Optional
from logger import get_logger
from models.brains_subscription_invitations import BrainSubscription
from models.settings import CommonsDep, common_dependencies
from models.settings import get_supabase_client
logger = get_logger(__name__)
class SubscriptionInvitationService:
def __init__(self, commons: Optional[CommonsDep] = None):
self.commons = common_dependencies()
def __init__(self):
self.supabase_client = get_supabase_client()
def create_subscription_invitation(self, brain_subscription: BrainSubscription):
logger.info("Creating subscription invitation")
response = (
self.commons["supabase"]
.table("brain_subscription_invitations")
self.supabase_client.table("brain_subscription_invitations")
.insert(
{
"brain_id": str(brain_subscription.brain_id),
@ -30,8 +27,7 @@ class SubscriptionInvitationService:
def update_subscription_invitation(self, brain_subscription: BrainSubscription):
logger.info("Updating subscription invitation")
response = (
self.commons["supabase"]
.table("brain_subscription_invitations")
self.supabase_client.table("brain_subscription_invitations")
.update({"rights": brain_subscription.rights})
.eq("brain_id", str(brain_subscription.brain_id))
.eq("email", brain_subscription.email)
@ -43,8 +39,7 @@ class SubscriptionInvitationService:
self, brain_subscription: BrainSubscription
):
response = (
self.commons["supabase"]
.table("brain_subscription_invitations")
self.supabase_client.table("brain_subscription_invitations")
.select("*")
.eq("brain_id", str(brain_subscription.brain_id))
.eq("email", brain_subscription.email)
@ -61,8 +56,7 @@ class SubscriptionInvitationService:
def fetch_invitation(self, subscription: BrainSubscription):
logger.info("Fetching subscription invitation")
response = (
self.commons["supabase"]
.table("brain_subscription_invitations")
self.supabase_client.table("brain_subscription_invitations")
.select("*")
.eq("brain_id", str(subscription.brain_id))
.eq("email", subscription.email)
@ -78,8 +72,7 @@ class SubscriptionInvitationService:
f"Removing subscription invitation for email {subscription.email} and brain {subscription.brain_id}"
)
response = (
self.commons["supabase"]
.table("brain_subscription_invitations")
self.supabase_client.table("brain_subscription_invitations")
.delete()
.eq("brain_id", str(subscription.brain_id))
.eq("email", subscription.email)

View File

@ -3,7 +3,7 @@ from uuid import UUID
from logger import get_logger
from models.chat import Chat
from models.settings import common_dependencies
from models.settings import get_supabase_db
logger = get_logger(__name__)
@ -17,7 +17,7 @@ class CreateChatProperties:
def create_chat(user_id: UUID, chat_data: CreateChatProperties) -> Chat:
commons = common_dependencies()
supabase_db = get_supabase_db()
# Chat is created upon the user's first question asked
logger.info(f"New chat entry in chats table for user {user_id}")
@ -27,7 +27,7 @@ def create_chat(user_id: UUID, chat_data: CreateChatProperties) -> Chat:
"user_id": str(user_id),
"chat_name": chat_data.name,
}
insert_response = commons["db"].create_chat(new_chat)
insert_response = supabase_db.create_chat(new_chat)
logger.info(f"Insert response {insert_response.data}")
return insert_response.data[0]

View File

@ -1,12 +1,9 @@
from models.chat import Chat
from models.settings import common_dependencies
from models.settings import get_supabase_db
def get_chat_by_id(chat_id: str) -> Chat:
commons = common_dependencies()
supabase_db = get_supabase_db()
response = (
commons["db"]
.get_chat_by_id(chat_id)
)
response = supabase_db.get_chat_by_id(chat_id)
return Chat(response.data[0])

View File

@ -1,12 +1,12 @@
from typing import List # For type hinting
from typing import List
from models.chat import ChatHistory
from models.settings import common_dependencies
from models.settings import get_supabase_db # For type hinting
def get_chat_history(chat_id: str) -> List[ChatHistory]:
commons = common_dependencies()
history: List[ChatHistory] = commons["db"].get_chat_history(chat_id).data
supabase_db = get_supabase_db()
history: List[ChatHistory] = supabase_db.get_chat_history(chat_id).data
if history is None:
return []
else:

View File

@ -1,11 +1,11 @@
from typing import List
from models.chat import Chat
from models.settings import common_dependencies
from models.settings import get_supabase_db
def get_user_chats(user_id: str) -> List[Chat]:
commons = common_dependencies()
response = commons["db"].get_user_chats(user_id)
supabase_db = get_supabase_db()
response = supabase_db.get_user_chats(user_id)
chats = [Chat(chat_dict) for chat_dict in response.data]
return chats

View File

@ -3,7 +3,7 @@ from typing import Optional
from logger import get_logger
from models.chat import Chat
from models.settings import common_dependencies
from models.settings import get_supabase_db
logger = get_logger(__name__)
@ -17,7 +17,7 @@ class ChatUpdatableProperties:
def update_chat(chat_id, chat_data: ChatUpdatableProperties) -> Chat:
commons = common_dependencies()
supabase_db = get_supabase_db()
if not chat_id:
logger.error("No chat_id provided")
@ -31,10 +31,7 @@ def update_chat(chat_id, chat_data: ChatUpdatableProperties) -> Chat:
updated_chat = None
if updates:
updated_chat = (
commons["db"]
.update_chat(chat_id, updates)
).data[0]
updated_chat = (supabase_db.update_chat(chat_id, updates)).data[0]
logger.info(f"Chat {chat_id} updated")
else:
logger.info(f"No updates to apply for chat {chat_id}")

View File

@ -2,13 +2,13 @@ from typing import List # For type hinting
from fastapi import HTTPException
from models.chat import ChatHistory
from models.settings import common_dependencies
from models.settings import get_supabase_db
def update_chat_history(chat_id: str, user_message: str, assistant: str) -> ChatHistory:
commons = common_dependencies()
supabase_db = get_supabase_db()
response: List[ChatHistory] = (
commons["db"].update_chat_history(chat_id, user_message, assistant)
supabase_db.update_chat_history(chat_id, user_message, assistant)
).data
if len(response) == 0:
raise HTTPException(

View File

@ -1,6 +1,6 @@
from logger import get_logger
from models.chat import ChatHistory
from models.settings import common_dependencies
from models.settings import get_supabase_db
logger = get_logger(__name__)
@ -10,7 +10,7 @@ def update_message_by_id(
user_message: str = None, # pyright: ignore reportPrivateUsage=none
assistant: str = None, # pyright: ignore reportPrivateUsage=none
) -> ChatHistory:
commons = common_dependencies()
supabase_db = get_supabase_db()
if not message_id:
logger.error("No message_id provided")
@ -27,10 +27,9 @@ def update_message_by_id(
updated_message = None
if updates:
updated_message = (
commons["db"]
.update_message_by_id(message_id, updates)
).data[0]
updated_message = (supabase_db.update_message_by_id(message_id, updates)).data[
0
]
logger.info(f"Message {message_id} updated")
else:
logger.info(f"No updates to apply for message {message_id}")

View File

@ -1,9 +1,9 @@
from models.databases.supabase.prompts import CreatePromptProperties
from models.prompt import Prompt
from models.settings import common_dependencies
from models.settings import get_supabase_db
def create_prompt(prompt: CreatePromptProperties) -> Prompt:
commons = common_dependencies()
supabase_db = get_supabase_db()
return commons["db"].create_prompt(prompt)
return supabase_db.create_prompt(prompt)

View File

@ -1,7 +1,7 @@
from uuid import UUID
from models.databases.supabase.prompts import DeletePromptResponse
from models.settings import common_dependencies
from models.settings import get_supabase_db
def delete_prompt_by_id(prompt_id: UUID) -> DeletePromptResponse:
@ -13,5 +13,5 @@ def delete_prompt_by_id(prompt_id: UUID) -> DeletePromptResponse:
Returns:
Prompt: The prompt
"""
commons = common_dependencies()
return commons["db"].delete_prompt_by_id(prompt_id)
supabase_db = get_supabase_db()
return supabase_db.delete_prompt_by_id(prompt_id)

View File

@ -1,7 +1,7 @@
from uuid import UUID
from models.prompt import Prompt
from models.settings import common_dependencies
from models.settings import get_supabase_db
def get_prompt_by_id(prompt_id: UUID) -> Prompt | None:
@ -14,5 +14,5 @@ def get_prompt_by_id(prompt_id: UUID) -> Prompt | None:
Returns:
Prompt: The prompt
"""
commons = common_dependencies()
return commons["db"].get_prompt_by_id(prompt_id)
supabase_db = get_supabase_db()
return supabase_db.get_prompt_by_id(prompt_id)

View File

@ -1,10 +1,10 @@
from models.prompt import Prompt
from models.settings import common_dependencies
from models.settings import get_supabase_db
def get_public_prompts() -> list[Prompt]:
"""
List all public prompts
"""
commons = common_dependencies()
return commons["db"].get_public_prompts()
supabase_db = get_supabase_db()
return supabase_db.get_public_prompts()

View File

@ -2,11 +2,11 @@ from uuid import UUID
from models.databases.supabase.prompts import PromptUpdatableProperties
from models.prompt import Prompt
from models.settings import common_dependencies
from models.settings import get_supabase_db
def update_prompt_by_id(prompt_id: UUID, prompt: PromptUpdatableProperties) -> Prompt:
"""Update a prompt by id"""
commons = common_dependencies()
supabase_db = get_supabase_db()
return commons["db"].update_prompt_by_id(prompt_id, prompt)
return supabase_db.update_prompt_by_id(prompt_id, prompt)

View File

@ -1,13 +1,11 @@
from uuid import UUID
from models.settings import common_dependencies
from models.settings import get_supabase_client
def get_user_email_by_user_id(user_id: UUID) -> str:
commons = common_dependencies()
response = (
commons["supabase"]
.rpc("get_user_email_by_user_id", {"user_id": user_id})
.execute()
)
supabase_client = get_supabase_client()
response = supabase_client.rpc(
"get_user_email_by_user_id", {"user_id": user_id}
).execute()
return response.data[0]["email"]

View File

@ -1,13 +1,11 @@
from uuid import UUID
from models.settings import common_dependencies
from models.settings import get_supabase_client
def get_user_id_by_user_email(email: str) -> UUID:
commons = common_dependencies()
response = (
commons["supabase"]
.rpc("get_user_id_by_user_email", {"user_email": email})
.execute()
)
supabase_client = get_supabase_client()
response = supabase_client.rpc(
"get_user_id_by_user_email", {"user_email": email}
).execute()
return response.data[0]["user_id"]

View File

@ -1,13 +1,13 @@
from models.settings import common_dependencies
from models.settings import get_supabase_client
from models.user_identity import UserIdentity
def create_user_identity(user_identity: UserIdentity) -> UserIdentity:
commons = common_dependencies()
supabase_client = get_supabase_client()
user_identity_dict = user_identity.dict()
user_identity_dict["user_id"] = str(user_identity.user_id)
response = (
commons["supabase"].from_("user_identity").insert(user_identity_dict).execute()
supabase_client.from_("user_identity").insert(user_identity_dict).execute()
)
return UserIdentity(**response.data[0])

View File

@ -1,15 +1,14 @@
from uuid import UUID
from models.settings import common_dependencies
from models.settings import get_supabase_client
from models.user_identity import UserIdentity
from repository.user_identity.create_user_identity import create_user_identity
def get_user_identity(user_id: UUID) -> UserIdentity:
commons = common_dependencies()
supabase_client = get_supabase_client()
response = (
commons["supabase"]
.from_("user_identity")
supabase_client.from_("user_identity")
.select("*")
.filter("user_id", "eq", user_id)
.execute()

View File

@ -1,12 +1,10 @@
from typing import Optional
from uuid import UUID
from models.settings import common_dependencies
from models.settings import get_supabase_client
from models.user_identity import UserIdentity
from pydantic import BaseModel
from repository.user_identity.create_user_identity import (
create_user_identity,
)
from repository.user_identity.create_user_identity import create_user_identity
class UserIdentityUpdatableProperties(BaseModel):
@ -17,10 +15,9 @@ def update_user_identity(
user_id: UUID,
user_identity_updatable_properties: UserIdentityUpdatableProperties,
) -> UserIdentity:
commons = common_dependencies()
supabase_client = get_supabase_client()
response = (
commons["supabase"]
.from_("user_identity")
supabase_client.from_("user_identity")
.update(user_identity_updatable_properties.__dict__)
.filter("user_id", "eq", user_id)
.execute()

View File

@ -6,7 +6,7 @@ from asyncpg.exceptions import UniqueViolationError
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends
from logger import get_logger
from models.settings import CommonsDep
from models.settings import get_supabase_db
from models.users import User
from pydantic import BaseModel
@ -32,9 +32,7 @@ api_key_router = APIRouter()
dependencies=[Depends(AuthBearer())],
tags=["API Key"],
)
async def create_api_key(
commons: CommonsDep, current_user: User = Depends(get_current_user)
):
async def create_api_key(current_user: User = Depends(get_current_user)):
"""
Create new API key for the current user.
@ -48,11 +46,12 @@ async def create_api_key(
new_key_id = uuid4()
new_api_key = token_hex(16)
api_key_inserted = False
supabase_db = get_supabase_db()
while not api_key_inserted:
try:
# Attempt to insert new API key into database
commons["db"].create_api_key(new_key_id, new_api_key, current_user.id)
supabase_db.create_api_key(new_key_id, new_api_key, current_user.id)
api_key_inserted = True
except UniqueViolationError:
@ -69,9 +68,7 @@ async def create_api_key(
@api_key_router.delete(
"/api-key/{key_id}", dependencies=[Depends(AuthBearer())], tags=["API Key"]
)
async def delete_api_key(
key_id: str, commons: CommonsDep, current_user: User = Depends(get_current_user)
):
async def delete_api_key(key_id: str, current_user: User = Depends(get_current_user)):
"""
Delete (deactivate) an API key for the current user.
@ -81,8 +78,8 @@ async def delete_api_key(
as inactive in the database.
"""
commons["db"].delete_api_key(key_id, current_user.id)
supabase_db = get_supabase_db()
supabase_db.delete_api_key(key_id, current_user.id)
return {"message": "API key deleted."}
@ -93,9 +90,7 @@ async def delete_api_key(
dependencies=[Depends(AuthBearer())],
tags=["API Key"],
)
async def get_api_keys(
commons: CommonsDep, current_user: User = Depends(get_current_user)
):
async def get_api_keys(current_user: User = Depends(get_current_user)):
"""
Get all active API keys for the current user.
@ -105,6 +100,6 @@ async def get_api_keys(
This endpoint retrieves all the active API keys associated with the current user. It returns a list of API key objects
containing the key ID and creation time for each API key.
"""
response = commons["db"].get_user_api_keys(current_user.id)
supabase_db = get_supabase_db()
response = supabase_db.get_user_api_keys(current_user.id)
return response.data

View File

@ -11,7 +11,8 @@ from llm.openai import OpenAIBrainPicking
from models.brains import Brain
from models.chat import Chat, ChatHistory
from models.chats import ChatQuestion
from models.settings import LLMSettings, common_dependencies
from models.databases.supabase.supabase import SupabaseDB
from models.settings import LLMSettings, get_supabase_db
from models.users import User
from repository.brain.get_brain_details import get_brain_details
from repository.brain.get_default_user_brain_or_create_new import (
@ -42,29 +43,19 @@ class NullableUUID(UUID):
return None
def get_chat_details(commons, chat_id):
return commons["db"].get_chat_details(chat_id)
def delete_chat_from_db(commons, chat_id):
def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
try:
commons["db"].delete_chat_history(chat_id)
supabase_db.delete_chat_history(chat_id)
except Exception as e:
print(e)
pass
try:
commons["db"].delete_chat(chat_id)
supabase_db.delete_chat(chat_id)
except Exception as e:
print(e)
pass
def fetch_user_stats(commons, user, date):
response = commons["db"].get_user_stats(user.email, date)
userItem = next(iter(response.data or []), {"requests_count": 0})
return userItem
def check_user_limit(
user: User,
):
@ -106,8 +97,8 @@ async def delete_chat(chat_id: UUID):
"""
Delete a specific chat by chat ID.
"""
commons = common_dependencies()
delete_chat_from_db(commons, chat_id)
supabase_db = get_supabase_db()
delete_chat_from_db(supabase_db=supabase_db, chat_id=chat_id)
return {"message": f"{chat_id} has been deleted."}

View File

@ -8,7 +8,6 @@ from crawl.crawler import CrawlWebsite
from fastapi import APIRouter, Depends, Query, Request, UploadFile
from models.brains import Brain
from models.files import File
from models.settings import common_dependencies
from models.users import User
from parsers.github import process_github
from utils.file import convert_bytes
@ -32,8 +31,6 @@ async def crawl_endpoint(
# [TODO] check if the user is the owner/editor of the brain
brain = Brain(id=brain_id)
commons = common_dependencies()
if request.headers.get("Openai-Api-Key"):
brain.max_brain_size = os.getenv(
"MAX_BRAIN_SIZE_WITH_KEY", 209715200
@ -66,19 +63,17 @@ async def crawl_endpoint(
file = File(file=uploadFile)
# check remaining free space here !!
message = await filter_file(
commons,
file,
enable_summarization,
brain.id,
file=file,
enable_summarization=enable_summarization,
brain_id=brain.id,
openai_api_key=request.headers.get("Openai-Api-Key", None),
)
return message
else:
# check remaining free space here !!
message = await process_github(
commons,
crawl_website.url,
"false",
brain_id,
repo=crawl_website.url,
enable_summarization="false",
brain_id=brain_id,
user_openai_api_key=request.headers.get("Openai-Api-Key", None),
)

View File

@ -3,9 +3,8 @@ from uuid import UUID
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Query
from models.brains import Brain
from models.settings import common_dependencies
from models.settings import get_supabase_db
from models.users import User
from routes.authorizations.brain_authorization import (
RoleEnum,
has_brain_authorization,
@ -64,8 +63,8 @@ async def download_endpoint(
"""
# check if user has the right to get the file: add brain_id to the query
commons = common_dependencies()
response = commons["db"].get_vectors_by_file_name(file_name)
supabase_db = get_supabase_db()
response = supabase_db.get_vectors_by_file_name(file_name)
documents = response.data
if len(documents) == 0:

View File

@ -5,17 +5,15 @@ from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Query, Request, UploadFile
from models.brains import Brain
from models.files import File
from models.settings import common_dependencies
from models.users import User
from repository.brain.get_brain_details import get_brain_details
from repository.user_identity.get_user_identity import get_user_identity
from utils.file import convert_bytes, get_file_size
from utils.processors import filter_file
from routes.authorizations.brain_authorization import (
RoleEnum,
validate_brain_authorization,
)
from utils.file import convert_bytes, get_file_size
from utils.processors import filter_file
upload_router = APIRouter()
@ -45,7 +43,6 @@ async def upload_file(
)
brain = Brain(id=brain_id)
commons = common_dependencies()
if request.headers.get("Openai-Api-Key"):
brain.max_brain_size = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200))
@ -71,9 +68,8 @@ async def upload_file(
openai_api_key = get_user_identity(current_user.id).openai_api_key
message = await filter_file(
commons,
file,
enable_summarization,
file=file,
enable_summarization=enable_summarization,
brain_id=brain_id,
openai_api_key=openai_api_key,
)

View File

@ -1,6 +1,5 @@
from models.brains import Brain
from models.files import File
from models.settings import CommonsDep
from parsers.audio import process_audio
from parsers.csv import process_csv
from parsers.docx import process_docx
@ -40,7 +39,6 @@ def create_response(message, type):
async def filter_file(
commons: CommonsDep,
file: File,
enable_summarization: bool,
brain_id,
@ -72,7 +70,10 @@ async def filter_file(
if file.file_extension in file_processors:
try:
await file_processors[file.file_extension](
commons, file, enable_summarization, brain_id, openai_api_key
file=file,
enable_summarization=enable_summarization,
brain_id=brain_id,
user_openai_api_key=openai_api_key,
)
return create_response(
f"{file.file.filename} has been uploaded to brain {brain_id}.", # pyright: ignore reportPrivateUsage=none

View File

@ -1,13 +0,0 @@
from logger import get_logger
from models.settings import CommonsDep
from models.users import User
logger = get_logger(__name__)
def create_user(commons: CommonsDep, user: User, date):
logger.info(f"New user entry in db document for user {user.email}")
return (
commons["db"].create_user(user.id, user.email, date)
)

View File

@ -3,25 +3,23 @@ from typing import List
from langchain.embeddings.openai import OpenAIEmbeddings
from logger import get_logger
from models.settings import BrainSettings, CommonsDep, common_dependencies
from models.settings import get_documents_vector_store, get_embeddings, get_supabase_db
from pydantic import BaseModel
logger = get_logger(__name__)
class Neurons(BaseModel):
commons: CommonsDep
settings = BrainSettings() # pyright: ignore reportPrivateUsage=none
def create_vector(self, doc, user_openai_api_key=None):
documents_vector_store = get_documents_vector_store()
logger.info("Creating vector for document")
logger.info(f"Document: {doc}")
if user_openai_api_key:
self.commons["documents_vector_store"]._embedding = OpenAIEmbeddings(
documents_vector_store._embedding = OpenAIEmbeddings(
openai_api_key=user_openai_api_key
) # pyright: ignore reportPrivateUsage=none
try:
sids = self.commons["documents_vector_store"].add_documents([doc])
sids = documents_vector_store.add_documents([doc])
if sids and len(sids) > 0:
return sids
@ -29,11 +27,15 @@ class Neurons(BaseModel):
logger.error(f"Error creating vector for document {e}")
def create_embedding(self, content):
return self.commons["embeddings"].embed_query(content)
embeddings = get_embeddings()
return embeddings.embed_query(content)
def similarity_search(self, query, table="match_summaries", top_k=5, threshold=0.5):
query_embedding = self.create_embedding(query)
summaries = self.commons["db"].similarity_search(query_embedding, table, top_k, threshold)
supabase_db = get_supabase_db()
summaries = supabase_db.similarity_search(
query_embedding, table, top_k, threshold
)
return summaries.data
@ -42,13 +44,13 @@ def error_callback(exception):
def process_batch(batch_ids: List[str]):
commons = common_dependencies()
db = commons["db"]
supabase_db = get_supabase_db()
try:
if len(batch_ids) == 1:
return (db.get_vectors_by_batch(batch_ids[0])).data
return (supabase_db.get_vectors_by_batch(batch_ids[0])).data
else:
return (db.get_vectors_in_batch(batch_ids)).data
return (supabase_db.get_vectors_in_batch(batch_ids)).data
except Exception as e:
logger.error("Error retrieving batched vectors", e)