diff --git a/.gitignore b/.gitignore
index 112174048..c034793f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -102,3 +102,4 @@ backend/core/examples/chatbot/.chainlit/translations/en-US.json
# Tox
.tox
Pipfile
+*.pkl
diff --git a/backend/api/pyproject.toml b/backend/api/pyproject.toml
index 62ce71e6f..f8873a15a 100644
--- a/backend/api/pyproject.toml
+++ b/backend/api/pyproject.toml
@@ -19,7 +19,6 @@ dependencies = [
"pydantic-settings>=2.4.0",
"python-dotenv>=1.0.1",
"unidecode>=1.3.8",
- "fpdf>=1.7.2",
"colorlog>=6.8.2",
"posthog>=3.5.0",
"pyinstrument>=4.7.2",
diff --git a/backend/api/quivr_api/logger.py b/backend/api/quivr_api/logger.py
index 0fc69cfce..b839e9aef 100644
--- a/backend/api/quivr_api/logger.py
+++ b/backend/api/quivr_api/logger.py
@@ -4,7 +4,7 @@ from logging.handlers import RotatingFileHandler
from colorlog import (
ColoredFormatter,
-) # You need to install this package: pip install colorlog
+)
def get_logger(logger_name, log_file="application.log"):
diff --git a/backend/api/quivr_api/middlewares/auth/jwt_token_handler.py b/backend/api/quivr_api/middlewares/auth/jwt_token_handler.py
index 4e412b41d..e6438d2ac 100644
--- a/backend/api/quivr_api/middlewares/auth/jwt_token_handler.py
+++ b/backend/api/quivr_api/middlewares/auth/jwt_token_handler.py
@@ -4,6 +4,7 @@ from typing import Optional
from jose import jwt
from jose.exceptions import JWTError
+
from quivr_api.modules.user.entity.user_identity import UserIdentity
SECRET_KEY = os.environ.get("JWT_SECRET_KEY")
diff --git a/backend/api/quivr_api/models/brains_subscription_invitations.py b/backend/api/quivr_api/models/brains_subscription_invitations.py
index bbe474c22..fc0ddf048 100644
--- a/backend/api/quivr_api/models/brains_subscription_invitations.py
+++ b/backend/api/quivr_api/models/brains_subscription_invitations.py
@@ -1,6 +1,7 @@
from uuid import UUID
from pydantic import BaseModel, ConfigDict
+
from quivr_api.logger import get_logger
logger = get_logger(__name__)
diff --git a/backend/api/quivr_api/modules/analytics/controller/analytics_routes.py b/backend/api/quivr_api/modules/analytics/controller/analytics_routes.py
index dfb889b38..4dc924032 100644
--- a/backend/api/quivr_api/modules/analytics/controller/analytics_routes.py
+++ b/backend/api/quivr_api/modules/analytics/controller/analytics_routes.py
@@ -1,6 +1,7 @@
from uuid import UUID
from fastapi import APIRouter, Depends, Query
+
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
from quivr_api.modules.analytics.entity.analytics import Range
from quivr_api.modules.analytics.service.analytics_service import AnalyticsService
diff --git a/backend/api/quivr_api/modules/analytics/entity/analytics.py b/backend/api/quivr_api/modules/analytics/entity/analytics.py
index 4b3424b02..6d8dced41 100644
--- a/backend/api/quivr_api/modules/analytics/entity/analytics.py
+++ b/backend/api/quivr_api/modules/analytics/entity/analytics.py
@@ -1,16 +1,20 @@
+from datetime import date
from enum import IntEnum
from typing import List
+
from pydantic import BaseModel
-from datetime import date
+
class Range(IntEnum):
WEEK = 7
MONTH = 30
QUARTER = 90
+
class Usage(BaseModel):
date: date
usage_count: int
+
class BrainsUsages(BaseModel):
- usages: List[Usage]
\ No newline at end of file
+ usages: List[Usage]
diff --git a/backend/api/quivr_api/modules/analytics/service/analytics_service.py b/backend/api/quivr_api/modules/analytics/service/analytics_service.py
index c7a60eaf3..34a90f325 100644
--- a/backend/api/quivr_api/modules/analytics/service/analytics_service.py
+++ b/backend/api/quivr_api/modules/analytics/service/analytics_service.py
@@ -11,5 +11,4 @@ class AnalyticsService:
self.repository = Analytics()
def get_brains_usages(self, user_id, graph_range, brain_id=None):
-
return self.repository.get_brains_usages(user_id, graph_range, brain_id)
diff --git a/backend/api/quivr_api/modules/api_key/controller/api_key_routes.py b/backend/api/quivr_api/modules/api_key/controller/api_key_routes.py
index dedf9f704..7215ae76d 100644
--- a/backend/api/quivr_api/modules/api_key/controller/api_key_routes.py
+++ b/backend/api/quivr_api/modules/api_key/controller/api_key_routes.py
@@ -3,6 +3,7 @@ from typing import List
from uuid import uuid4
from fastapi import APIRouter, Depends
+
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.api_key.dto.outputs import ApiKeyInfo
diff --git a/backend/api/quivr_api/modules/api_key/service/api_key_service.py b/backend/api/quivr_api/modules/api_key/service/api_key_service.py
index f459eff38..9290a51e4 100644
--- a/backend/api/quivr_api/modules/api_key/service/api_key_service.py
+++ b/backend/api/quivr_api/modules/api_key/service/api_key_service.py
@@ -1,6 +1,7 @@
from datetime import datetime
from fastapi import HTTPException
+
from quivr_api.logger import get_logger
from quivr_api.modules.api_key.repository.api_key_interface import ApiKeysInterface
from quivr_api.modules.api_key.repository.api_keys import ApiKeys
diff --git a/backend/api/quivr_api/modules/assistant/controller/__init__.py b/backend/api/quivr_api/modules/assistant/controller/__init__.py
index 64ec4d89f..cc8eb3907 100644
--- a/backend/api/quivr_api/modules/assistant/controller/__init__.py
+++ b/backend/api/quivr_api/modules/assistant/controller/__init__.py
@@ -1 +1,6 @@
+# noqa:
from .assistant_routes import assistant_router
+
+__all__ = [
+ "assistant_router",
+]
diff --git a/backend/api/quivr_api/modules/assistant/controller/assistant_routes.py b/backend/api/quivr_api/modules/assistant/controller/assistant_routes.py
index 3f5605a14..9d2e303bb 100644
--- a/backend/api/quivr_api/modules/assistant/controller/assistant_routes.py
+++ b/backend/api/quivr_api/modules/assistant/controller/assistant_routes.py
@@ -1,63 +1,176 @@
-from typing import List
+import io
+from typing import Annotated, List
+from uuid import uuid4
-from fastapi import APIRouter, Depends, HTTPException, UploadFile
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
+
+from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
-from quivr_api.middlewares.auth import AuthBearer, get_current_user
-from quivr_api.modules.assistant.dto.inputs import InputAssistant
+from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
+from quivr_api.modules.assistant.controller.assistants_definition import (
+ assistants,
+ validate_assistant_input,
+)
+from quivr_api.modules.assistant.dto.inputs import CreateTask, InputAssistant
from quivr_api.modules.assistant.dto.outputs import AssistantOutput
-from quivr_api.modules.assistant.ito.difference import DifferenceAssistant
-from quivr_api.modules.assistant.ito.summary import SummaryAssistant, summary_inputs
-from quivr_api.modules.assistant.service.assistant import Assistant
+from quivr_api.modules.assistant.entity.assistant_entity import (
+ AssistantSettings,
+)
+from quivr_api.modules.assistant.services.tasks_service import TasksService
+from quivr_api.modules.dependencies import get_service
+from quivr_api.modules.upload.service.upload_file import (
+ upload_file_storage,
+)
from quivr_api.modules.user.entity.user_identity import UserIdentity
-assistant_router = APIRouter()
logger = get_logger(__name__)
-assistant_service = Assistant()
+
+assistant_router = APIRouter()
+
+
+TasksServiceDep = Annotated[TasksService, Depends(get_service(TasksService))]
+UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
@assistant_router.get(
"/assistants", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
)
-async def list_assistants(
+async def get_assistants(
+ request: Request,
current_user: UserIdentity = Depends(get_current_user),
) -> List[AssistantOutput]:
- """
- Retrieve and list all the knowledge in a brain.
- """
+ logger.info("Getting assistants")
- summary = summary_inputs()
- # difference = difference_inputs()
- # crawler = crawler_inputs()
- return [summary]
+ return assistants
+
+
+@assistant_router.get(
+ "/assistants/tasks", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
+)
+async def get_tasks(
+ request: Request,
+ current_user: UserIdentityDep,
+ tasks_service: TasksServiceDep,
+):
+ logger.info("Getting tasks")
+ return await tasks_service.get_tasks_by_user_id(current_user.id)
@assistant_router.post(
- "/assistant/process",
+ "/assistants/task", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
+)
+async def create_task(
+ current_user: UserIdentityDep,
+ tasks_service: TasksServiceDep,
+ request: Request,
+ input: InputAssistant,
+ files: List[UploadFile] = None,
+):
+ assistant = next(
+ (assistant for assistant in assistants if assistant.id == input.id), None
+ )
+ if assistant is None:
+ raise HTTPException(status_code=404, detail="Assistant not found")
+
+ is_valid, validation_errors = validate_assistant_input(input, assistant)
+ if not is_valid:
+ for error in validation_errors:
+ print(error)
+ raise HTTPException(status_code=400, detail=error)
+ else:
+ print("Assistant input is valid.")
+ notification_uuid = uuid4()
+
+ # Process files dynamically
+ for upload_file in files:
+ file_name_path = f"{input.id}/{notification_uuid}/{upload_file.filename}"
+ buff_reader = io.BufferedReader(upload_file.file) # type: ignore
+ try:
+ await upload_file_storage(buff_reader, file_name_path)
+ except Exception as e:
+ logger.exception(f"Exception in upload_route {e}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to upload file to storage. {e}"
+ )
+
+ task = CreateTask(
+ assistant_id=input.id,
+ pretty_id=str(notification_uuid),
+ settings=input.model_dump(mode="json"),
+ )
+
+ task_created = await tasks_service.create_task(task, current_user.id)
+
+ celery.send_task(
+ "process_assistant_task",
+ kwargs={
+ "assistant_id": input.id,
+ "notification_uuid": notification_uuid,
+ "task_id": task_created.id,
+ "user_id": str(current_user.id),
+ },
+ )
+ return task_created
+
+
+@assistant_router.get(
+ "/assistants/task/{task_id}",
dependencies=[Depends(AuthBearer())],
tags=["Assistant"],
)
-async def process_assistant(
- input: InputAssistant,
- files: List[UploadFile] = None,
- current_user: UserIdentity = Depends(get_current_user),
+async def get_task(
+ request: Request,
+ task_id: str,
+ current_user: UserIdentityDep,
+ tasks_service: TasksServiceDep,
):
- if input.name.lower() == "summary":
- summary_assistant = SummaryAssistant(
- input=input, files=files, current_user=current_user
- )
- try:
- summary_assistant.check_input()
- return await summary_assistant.process_assistant()
- except ValueError as e:
- raise HTTPException(status_code=400, detail=str(e))
- elif input.name.lower() == "difference":
- difference_assistant = DifferenceAssistant(
- input=input, files=files, current_user=current_user
- )
- try:
- difference_assistant.check_input()
- return await difference_assistant.process_assistant()
- except ValueError as e:
- raise HTTPException(status_code=400, detail=str(e))
- return {"message": "Assistant not found"}
+ return await tasks_service.get_task_by_id(task_id, current_user.id) # type: ignore
+
+
+@assistant_router.delete(
+ "/assistants/task/{task_id}",
+ dependencies=[Depends(AuthBearer())],
+ tags=["Assistant"],
+)
+async def delete_task(
+ request: Request,
+ task_id: int,
+ current_user: UserIdentityDep,
+ tasks_service: TasksServiceDep,
+):
+ return await tasks_service.delete_task(task_id, current_user.id)
+
+
+@assistant_router.get(
+ "/assistants/task/{task_id}/download",
+ dependencies=[Depends(AuthBearer())],
+ tags=["Assistant"],
+)
+async def get_download_link_task(
+ request: Request,
+ task_id: int,
+ current_user: UserIdentityDep,
+ tasks_service: TasksServiceDep,
+):
+ return await tasks_service.get_download_link_task(task_id, current_user.id)
+
+
+@assistant_router.get(
+ "/assistants/{assistant_id}/config",
+ dependencies=[Depends(AuthBearer())],
+ tags=["Assistant"],
+ response_model=AssistantSettings,
+ summary="Retrieve assistant configuration",
+ description="Get the settings and file requirements for the specified assistant.",
+)
+async def get_assistant_config(
+ assistant_id: int,
+ current_user: UserIdentityDep,
+):
+ assistant = next(
+ (assistant for assistant in assistants if assistant.id == assistant_id), None
+ )
+ if assistant is None:
+ raise HTTPException(status_code=404, detail="Assistant not found")
+ return assistant.settings
diff --git a/backend/api/quivr_api/modules/assistant/controller/assistants_definition.py b/backend/api/quivr_api/modules/assistant/controller/assistants_definition.py
new file mode 100644
index 000000000..0ade87168
--- /dev/null
+++ b/backend/api/quivr_api/modules/assistant/controller/assistants_definition.py
@@ -0,0 +1,201 @@
+from quivr_api.modules.assistant.dto.inputs import InputAssistant
+from quivr_api.modules.assistant.dto.outputs import (
+ AssistantOutput,
+ InputFile,
+ Inputs,
+ Pricing,
+)
+
+
+def validate_assistant_input(
+ assistant_input: InputAssistant, assistant_output: AssistantOutput
+):
+ errors = []
+
+ # Validate files
+ if assistant_output.inputs.files:
+ required_files = [
+ file for file in assistant_output.inputs.files if file.required
+ ]
+ input_files = {
+ file_input.key for file_input in (assistant_input.inputs.files or [])
+ }
+ for req_file in required_files:
+ if req_file.key not in input_files:
+ errors.append(f"Missing required file input: {req_file.key}")
+
+ # Validate URLs
+ if assistant_output.inputs.urls:
+ required_urls = [url for url in assistant_output.inputs.urls if url.required]
+ input_urls = {
+ url_input.key for url_input in (assistant_input.inputs.urls or [])
+ }
+ for req_url in required_urls:
+ if req_url.key not in input_urls:
+ errors.append(f"Missing required URL input: {req_url.key}")
+
+ # Validate texts
+ if assistant_output.inputs.texts:
+ required_texts = [
+ text for text in assistant_output.inputs.texts if text.required
+ ]
+ input_texts = {
+ text_input.key for text_input in (assistant_input.inputs.texts or [])
+ }
+ for req_text in required_texts:
+ if req_text.key not in input_texts:
+ errors.append(f"Missing required text input: {req_text.key}")
+ else:
+ # Validate regex if applicable
+ req_text_val = next(
+ (t for t in assistant_output.inputs.texts if t.key == req_text.key),
+ None,
+ )
+ if req_text_val and req_text_val.validation_regex:
+ import re
+
+ input_value = next(
+ (
+ t.value
+ for t in assistant_input.inputs.texts
+ if t.key == req_text.key
+ ),
+ "",
+ )
+ if not re.match(req_text_val.validation_regex, input_value):
+ errors.append(
+ f"Text input '{req_text.key}' does not match the required format."
+ )
+
+ # Validate booleans
+ if assistant_output.inputs.booleans:
+ required_booleans = [b for b in assistant_output.inputs.booleans if b.required]
+ input_booleans = {
+ b_input.key for b_input in (assistant_input.inputs.booleans or [])
+ }
+ for req_bool in required_booleans:
+ if req_bool.key not in input_booleans:
+ errors.append(f"Missing required boolean input: {req_bool.key}")
+
+ # Validate numbers
+ if assistant_output.inputs.numbers:
+ required_numbers = [n for n in assistant_output.inputs.numbers if n.required]
+ input_numbers = {
+ n_input.key for n_input in (assistant_input.inputs.numbers or [])
+ }
+ for req_number in required_numbers:
+ if req_number.key not in input_numbers:
+ errors.append(f"Missing required number input: {req_number.key}")
+ else:
+ # Validate min and max
+ input_value = next(
+ (
+ n.value
+ for n in assistant_input.inputs.numbers
+ if n.key == req_number.key
+ ),
+ None,
+ )
+ if req_number.min is not None and input_value < req_number.min:
+ errors.append(
+ f"Number input '{req_number.key}' is below minimum value."
+ )
+ if req_number.max is not None and input_value > req_number.max:
+ errors.append(
+ f"Number input '{req_number.key}' exceeds maximum value."
+ )
+
+ # Validate select_texts
+ if assistant_output.inputs.select_texts:
+ required_select_texts = [
+ st for st in assistant_output.inputs.select_texts if st.required
+ ]
+ input_select_texts = {
+ st_input.key for st_input in (assistant_input.inputs.select_texts or [])
+ }
+ for req_select in required_select_texts:
+ if req_select.key not in input_select_texts:
+ errors.append(f"Missing required select text input: {req_select.key}")
+ else:
+ input_value = next(
+ (
+ st.value
+ for st in assistant_input.inputs.select_texts
+ if st.key == req_select.key
+ ),
+ None,
+ )
+ if input_value not in req_select.options:
+ errors.append(f"Invalid option for select text '{req_select.key}'.")
+
+ # Validate select_numbers
+ if assistant_output.inputs.select_numbers:
+ required_select_numbers = [
+ sn for sn in assistant_output.inputs.select_numbers if sn.required
+ ]
+ input_select_numbers = {
+ sn_input.key for sn_input in (assistant_input.inputs.select_numbers or [])
+ }
+ for req_select in required_select_numbers:
+ if req_select.key not in input_select_numbers:
+ errors.append(f"Missing required select number input: {req_select.key}")
+ else:
+ input_value = next(
+ (
+ sn.value
+ for sn in assistant_input.inputs.select_numbers
+ if sn.key == req_select.key
+ ),
+ None,
+ )
+ if input_value not in req_select.options:
+ errors.append(
+ f"Invalid option for select number '{req_select.key}'."
+ )
+
+ # Validate brain input
+ if assistant_output.inputs.brain and assistant_output.inputs.brain.required:
+ if not assistant_input.inputs.brain or not assistant_input.inputs.brain.value:
+ errors.append("Missing required brain input.")
+
+ if errors:
+ return False, errors
+ else:
+ return True, None
+
+
+assistant1 = AssistantOutput(
+ id=1,
+ name="Assistant 1",
+ description="Assistant 1 description",
+ pricing=Pricing(),
+ tags=["tag1", "tag2"],
+ input_description="Input description",
+ output_description="Output description",
+ inputs=Inputs(
+ files=[
+ InputFile(key="file_1", description="File description"),
+ InputFile(key="file_2", description="File description"),
+ ],
+ ),
+ icon_url="https://example.com/icon.png",
+)
+
+assistant2 = AssistantOutput(
+ id=2,
+ name="Assistant 2",
+ description="Assistant 2 description",
+ pricing=Pricing(),
+ tags=["tag1", "tag2"],
+ input_description="Input description",
+ output_description="Output description",
+ icon_url="https://example.com/icon.png",
+ inputs=Inputs(
+ files=[
+ InputFile(key="file_1", description="File description"),
+ InputFile(key="file_2", description="File description"),
+ ],
+ ),
+)
+
+assistants = [assistant1, assistant2]
diff --git a/backend/api/quivr_api/modules/assistant/dto/inputs.py b/backend/api/quivr_api/modules/assistant/dto/inputs.py
index 631f3e4fe..929f95535 100644
--- a/backend/api/quivr_api/modules/assistant/dto/inputs.py
+++ b/backend/api/quivr_api/modules/assistant/dto/inputs.py
@@ -1,16 +1,16 @@
-import json
from typing import List, Optional
from uuid import UUID
-from pydantic import BaseModel, model_validator, root_validator
+from pydantic import BaseModel, root_validator
-class EmailInput(BaseModel):
- activated: bool
+class CreateTask(BaseModel):
+ pretty_id: str
+ assistant_id: int
+ settings: dict
class BrainInput(BaseModel):
- activated: Optional[bool] = False
value: Optional[UUID] = None
@root_validator(pre=True)
@@ -64,19 +64,10 @@ class Inputs(BaseModel):
numbers: Optional[List[InputNumber]] = None
select_texts: Optional[List[InputSelectText]] = None
select_numbers: Optional[List[InputSelectNumber]] = None
-
-
-class Outputs(BaseModel):
- email: Optional[EmailInput] = None
brain: Optional[BrainInput] = None
class InputAssistant(BaseModel):
+ id: int
name: str
inputs: Inputs
- outputs: Outputs
-
- @model_validator(mode="before")
- @classmethod
- def to_py_dict(cls, data):
- return json.loads(data)
diff --git a/backend/api/quivr_api/modules/assistant/dto/outputs.py b/backend/api/quivr_api/modules/assistant/dto/outputs.py
index cf6398ad8..40574e5bf 100644
--- a/backend/api/quivr_api/modules/assistant/dto/outputs.py
+++ b/backend/api/quivr_api/modules/assistant/dto/outputs.py
@@ -3,6 +3,12 @@ from typing import List, Optional
from pydantic import BaseModel
+class BrainInput(BaseModel):
+ required: Optional[bool] = True
+ description: str
+ type: str
+
+
class InputFile(BaseModel):
key: str
allowed_extensions: Optional[List[str]] = ["pdf"]
@@ -63,23 +69,7 @@ class Inputs(BaseModel):
numbers: Optional[List[InputNumber]] = None
select_texts: Optional[List[InputSelectText]] = None
select_numbers: Optional[List[InputSelectNumber]] = None
-
-
-class OutputEmail(BaseModel):
- required: Optional[bool] = True
- description: str
- type: str
-
-
-class OutputBrain(BaseModel):
- required: Optional[bool] = True
- description: str
- type: str
-
-
-class Outputs(BaseModel):
- email: Optional[OutputEmail] = None
- brain: Optional[OutputBrain] = None
+ brain: Optional[BrainInput] = None
class Pricing(BaseModel):
@@ -88,6 +78,7 @@ class Pricing(BaseModel):
class AssistantOutput(BaseModel):
+ id: int
name: str
description: str
pricing: Optional[Pricing] = Pricing()
@@ -95,5 +86,4 @@ class AssistantOutput(BaseModel):
input_description: str
output_description: str
inputs: Inputs
- outputs: Outputs
icon_url: Optional[str] = None
diff --git a/backend/api/quivr_api/modules/assistant/entity/__init__.py b/backend/api/quivr_api/modules/assistant/entity/__init__.py
index b848b1efc..e69de29bb 100644
--- a/backend/api/quivr_api/modules/assistant/entity/__init__.py
+++ b/backend/api/quivr_api/modules/assistant/entity/__init__.py
@@ -1 +0,0 @@
-from .assistant import AssistantEntity
diff --git a/backend/api/quivr_api/modules/assistant/entity/assistant.py b/backend/api/quivr_api/modules/assistant/entity/assistant.py
deleted file mode 100644
index 00bcac691..000000000
--- a/backend/api/quivr_api/modules/assistant/entity/assistant.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from uuid import UUID
-
-from pydantic import BaseModel
-
-
-class AssistantEntity(BaseModel):
- id: UUID
- name: str
- brain_id_required: bool
- file_1_required: bool
- url_required: bool
diff --git a/backend/api/quivr_api/modules/assistant/entity/assistant_entity.py b/backend/api/quivr_api/modules/assistant/entity/assistant_entity.py
new file mode 100644
index 000000000..6321ff1f4
--- /dev/null
+++ b/backend/api/quivr_api/modules/assistant/entity/assistant_entity.py
@@ -0,0 +1,33 @@
+from typing import Any, List, Optional
+
+from pydantic import BaseModel
+
+
+class AssistantFileRequirement(BaseModel):
+ name: str
+ description: Optional[str] = None
+ required: bool = True
+ accepted_types: Optional[List[str]] = None # e.g., ['text/csv', 'application/json']
+
+
+class AssistantInput(BaseModel):
+ name: str
+ description: str
+ type: str # e.g., 'boolean', 'uuid', 'string'
+ required: bool = True
+ regex: Optional[str] = None
+ options: Optional[List[Any]] = None # For predefined choices
+ default: Optional[Any] = None
+
+
+class AssistantSettings(BaseModel):
+ inputs: List[AssistantInput]
+ files: Optional[List[AssistantFileRequirement]] = None
+
+
+class Assistant(BaseModel):
+ id: int
+ name: str
+ description: str
+ settings: AssistantSettings
+ required_files: Optional[List[str]] = None # List of required file names
diff --git a/backend/api/quivr_api/modules/assistant/entity/task_entity.py b/backend/api/quivr_api/modules/assistant/entity/task_entity.py
new file mode 100644
index 000000000..01d5f33b2
--- /dev/null
+++ b/backend/api/quivr_api/modules/assistant/entity/task_entity.py
@@ -0,0 +1,34 @@
+from datetime import datetime
+from typing import Dict
+from uuid import UUID
+
+from sqlmodel import JSON, TIMESTAMP, BigInteger, Column, Field, SQLModel, text
+
+
+class Task(SQLModel, table=True):
+ __tablename__ = "tasks" # type: ignore
+
+ id: int | None = Field(
+ default=None,
+ sa_column=Column(
+ BigInteger,
+ primary_key=True,
+ autoincrement=True,
+ ),
+ )
+ assistant_id: int
+ pretty_id: str
+ user_id: UUID
+ status: str = Field(default="pending")
+ creation_time: datetime | None = Field(
+ default=None,
+ sa_column=Column(
+ TIMESTAMP(timezone=False),
+ server_default=text("CURRENT_TIMESTAMP"),
+ ),
+ )
+ settings: Dict = Field(default_factory=dict, sa_column=Column(JSON))
+ answer: str | None = Field(default=None)
+
+ class Config:
+ arbitrary_types_allowed = True
diff --git a/backend/api/quivr_api/modules/assistant/ito/crawler.py b/backend/api/quivr_api/modules/assistant/ito/crawler.py
deleted file mode 100644
index c944c0895..000000000
--- a/backend/api/quivr_api/modules/assistant/ito/crawler.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from bs4 import BeautifulSoup as Soup
-from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
-from quivr_api.logger import get_logger
-from quivr_api.modules.assistant.dto.outputs import (
- AssistantOutput,
- Inputs,
- InputUrl,
- OutputBrain,
- OutputEmail,
- Outputs,
-)
-from quivr_api.modules.assistant.ito.ito import ITO
-
-logger = get_logger(__name__)
-
-
-class CrawlerAssistant(ITO):
-
- def __init__(
- self,
- **kwargs,
- ):
- super().__init__(
- **kwargs,
- )
-
- async def process_assistant(self):
-
- url = self.url
- loader = RecursiveUrlLoader(
- url=url, max_depth=2, extractor=lambda x: Soup(x, "html.parser").text
- )
- docs = loader.load()
-
- nice_url = url.split("://")[1].replace("/", "_").replace(".", "_")
- nice_url += ".txt"
-
- for docs in docs:
- await self.create_and_upload_processed_file(
- docs.page_content, nice_url, "Crawler"
- )
-
-
-def crawler_inputs():
- output = AssistantOutput(
- name="Crawler",
- description="Crawls a website and extracts the text from the pages",
- tags=["new"],
- input_description="One URL to crawl",
- output_description="Text extracted from the pages",
- inputs=Inputs(
- urls=[
- InputUrl(
- key="url",
- required=True,
- description="The URL to crawl",
- )
- ],
- ),
- outputs=Outputs(
- brain=OutputBrain(
- required=True,
- description="The brain to which upload the document",
- type="uuid",
- ),
- email=OutputEmail(
- required=True,
- description="Send the document by email",
- type="str",
- ),
- ),
- )
- return output
diff --git a/backend/api/quivr_api/modules/assistant/ito/difference.py b/backend/api/quivr_api/modules/assistant/ito/difference.py
deleted file mode 100644
index fd19d0648..000000000
--- a/backend/api/quivr_api/modules/assistant/ito/difference.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import os
-import tempfile
-from typing import List
-
-from fastapi import UploadFile
-from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
-from langchain_community.chat_models import ChatLiteLLM
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
-from llama_parse import LlamaParse
-from quivr_api.logger import get_logger
-from quivr_api.modules.assistant.dto.inputs import InputAssistant
-from quivr_api.modules.assistant.dto.outputs import (
- AssistantOutput,
- InputFile,
- Inputs,
- OutputBrain,
- OutputEmail,
- Outputs,
-)
-from quivr_api.modules.assistant.ito.ito import ITO
-from quivr_api.modules.user.entity.user_identity import UserIdentity
-
-logger = get_logger(__name__)
-
-
-class DifferenceAssistant(ITO):
-
- def __init__(
- self,
- input: InputAssistant,
- files: List[UploadFile] = None,
- current_user: UserIdentity = None,
- **kwargs,
- ):
- super().__init__(
- input=input,
- files=files,
- current_user=current_user,
- **kwargs,
- )
-
- def check_input(self):
- if not self.files:
- raise ValueError("No file was uploaded")
- if len(self.files) != 2:
- raise ValueError("Only two files can be uploaded")
- if not self.input.inputs.files:
- raise ValueError("No files key were given in the input")
- if len(self.input.inputs.files) != 2:
- raise ValueError("Only two files can be uploaded")
- if not self.input.inputs.files[0].key == "doc_1":
- raise ValueError("The key of the first file should be doc_1")
- if not self.input.inputs.files[1].key == "doc_2":
- raise ValueError("The key of the second file should be doc_2")
- if not self.input.inputs.files[0].value:
- raise ValueError("No file was uploaded")
- if not (
- self.input.outputs.brain.activated or self.input.outputs.email.activated
- ):
- raise ValueError("No output was selected")
- return True
-
- async def process_assistant(self):
-
- document_1 = self.files[0]
- document_2 = self.files[1]
-
- # Get the file extensions
- document_1_ext = os.path.splitext(document_1.filename)[1]
- document_2_ext = os.path.splitext(document_2.filename)[1]
-
- # Create temporary files with the same extension as the original files
- document_1_tmp = tempfile.NamedTemporaryFile(
- suffix=document_1_ext, delete=False
- )
- document_2_tmp = tempfile.NamedTemporaryFile(
- suffix=document_2_ext, delete=False
- )
-
- document_1_tmp.write(document_1.file.read())
- document_2_tmp.write(document_2.file.read())
-
- parser = LlamaParse(
- result_type="markdown" # "markdown" and "text" are available
- )
-
- document_1_llama_parsed = parser.load_data(document_1_tmp.name)
- document_2_llama_parsed = parser.load_data(document_2_tmp.name)
-
- document_1_tmp.close()
- document_2_tmp.close()
-
- document_1_to_langchain = document_1_llama_parsed[0].to_langchain_format()
- document_2_to_langchain = document_2_llama_parsed[0].to_langchain_format()
-
- llm = ChatLiteLLM(model="gpt-4o")
-
- human_prompt = """Given the following two documents, find the difference between them:
-
- Document 1:
- {document_1}
- Document 2:
- {document_2}
- Difference:
- """
- CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(human_prompt)
-
- system_message_template = """
- You are an expert in finding the difference between two documents. You look deeply into what makes the two documents different and provide a detailed analysis if needed of the differences between the two documents.
- If no differences are found, simply say that there are no differences.
- """
-
- ANSWER_PROMPT = ChatPromptTemplate.from_messages(
- [
- SystemMessagePromptTemplate.from_template(system_message_template),
- HumanMessagePromptTemplate.from_template(human_prompt),
- ]
- )
-
- final_inputs = {
- "document_1": document_1_to_langchain.page_content,
- "document_2": document_2_to_langchain.page_content,
- }
-
- output_parser = StrOutputParser()
-
- chain = ANSWER_PROMPT | llm | output_parser
- result = chain.invoke(final_inputs)
-
- return result
-
-
-def difference_inputs():
- output = AssistantOutput(
- name="difference",
- description="Finds the difference between two sets of documents",
- tags=["new"],
- input_description="Two documents to compare",
- output_description="The difference between the two documents",
- icon_url="https://quivr-cms.s3.eu-west-3.amazonaws.com/report_94bea8b918.png",
- inputs=Inputs(
- files=[
- InputFile(
- key="doc_1",
- allowed_extensions=["pdf"],
- required=True,
- description="The first document to compare",
- ),
- InputFile(
- key="doc_2",
- allowed_extensions=["pdf"],
- required=True,
- description="The second document to compare",
- ),
- ]
- ),
- outputs=Outputs(
- brain=OutputBrain(
- required=True,
- description="The brain to which upload the document",
- type="uuid",
- ),
- email=OutputEmail(
- required=True,
- description="Send the document by email",
- type="str",
- ),
- ),
- )
- return output
diff --git a/backend/api/quivr_api/modules/assistant/ito/ito.py b/backend/api/quivr_api/modules/assistant/ito/ito.py
deleted file mode 100644
index 39d4d02f6..000000000
--- a/backend/api/quivr_api/modules/assistant/ito/ito.py
+++ /dev/null
@@ -1,196 +0,0 @@
-import os
-import random
-import re
-import string
-from abc import abstractmethod
-from io import BytesIO
-from tempfile import NamedTemporaryFile
-from typing import List, Optional
-
-from fastapi import UploadFile
-from pydantic import BaseModel
-from unidecode import unidecode
-
-from quivr_api.logger import get_logger
-from quivr_api.models.settings import SendEmailSettings
-from quivr_api.modules.assistant.dto.inputs import InputAssistant
-from quivr_api.modules.assistant.ito.utils.pdf_generator import PDFGenerator, PDFModel
-from quivr_api.modules.chat.controller.chat.utils import update_user_usage
-from quivr_api.modules.upload.controller.upload_routes import upload_file
-from quivr_api.modules.user.entity.user_identity import UserIdentity
-from quivr_api.modules.user.service.user_usage import UserUsage
-from quivr_api.utils.send_email import send_email
-
-logger = get_logger(__name__)
-
-
-class ITO(BaseModel):
- input: InputAssistant
- files: List[UploadFile]
- current_user: UserIdentity
- user_usage: Optional[UserUsage] = None
- user_settings: Optional[dict] = None
-
- def __init__(
- self,
- input: InputAssistant,
- files: List[UploadFile] = None,
- current_user: UserIdentity = None,
- **kwargs,
- ):
- super().__init__(
- input=input,
- files=files,
- current_user=current_user,
- **kwargs,
- )
- self.user_usage = UserUsage(
- id=current_user.id,
- email=current_user.email,
- )
- self.user_settings = self.user_usage.get_user_settings()
- self.increase_usage_user()
-
- def increase_usage_user(self):
- # Raises an error if the user has consumed all of of his credits
-
- update_user_usage(
- usage=self.user_usage,
- user_settings=self.user_settings,
- cost=self.calculate_pricing(),
- )
-
- def calculate_pricing(self):
- return 20
-
- def generate_pdf(self, filename: str, title: str, content: str):
- pdf_model = PDFModel(title=title, content=content)
- pdf = PDFGenerator(pdf_model)
- pdf.print_pdf()
- pdf.output(filename, "F")
-
- @abstractmethod
- async def process_assistant(self):
- pass
-
- async def send_output_by_email(
- self,
- file: UploadFile,
- filename: str,
- task_name: str,
- custom_message: str,
- brain_id: str = None,
- ):
- settings = SendEmailSettings()
- file = await self.uploadfile_to_file(file)
- domain_quivr = os.getenv("QUIVR_DOMAIN", "https://chat.quivr.app/")
-
- with open(file.name, "rb") as f:
- mail_from = settings.resend_contact_sales_from
- mail_to = self.current_user.email
- body = f"""
-
-
-
-
Quivr's ingestion process has been completed. The processed file is attached.
-
-
Task: {task_name}
-
-
Output: {custom_message}
-
-
-
-
- """
- if brain_id:
- body += f"You can find the file
here.
"
- body += """
-
-
Please let us know if you have any questions or need further assistance.
-
-
The Quivr Team
-
- """
- params = {
- "from": mail_from,
- "to": [mail_to],
- "subject": "Quivr Ingestion Processed",
- "reply_to": "no-reply@quivr.app",
- "html": body,
- "attachments": [{"filename": filename, "content": list(f.read())}],
- }
- logger.info(f"Sending email to {mail_to} with file {filename}")
- send_email(params)
-
- async def uploadfile_to_file(self, uploadFile: UploadFile):
- # Transform the UploadFile object to a file object with same name and content
- tmp_file = NamedTemporaryFile(delete=False)
- tmp_file.write(uploadFile.file.read())
- tmp_file.flush() # Make sure all data is written to disk
- return tmp_file
-
- async def create_and_upload_processed_file(
- self, processed_content: str, original_filename: str, file_description: str
- ) -> dict:
- """Handles creation and uploading of the processed file."""
- # remove any special characters from the filename that aren't http safe
-
- new_filename = (
- original_filename.split(".")[0]
- + "_"
- + file_description.lower().replace(" ", "_")
- + "_"
- + str(random.randint(1000, 9999))
- + ".pdf"
- )
- new_filename = unidecode(new_filename)
- new_filename = re.sub(
- "[^{}0-9a-zA-Z]".format(re.escape(string.punctuation)), "", new_filename
- )
-
- self.generate_pdf(
- new_filename,
- f"{file_description} of {original_filename}",
- processed_content,
- )
-
- content_io = BytesIO()
- with open(new_filename, "rb") as f:
- content_io.write(f.read())
- content_io.seek(0)
-
- file_to_upload = UploadFile(
- filename=new_filename,
- file=content_io,
- headers={"content-type": "application/pdf"},
- )
-
- if self.input.outputs.email.activated:
- await self.send_output_by_email(
- file_to_upload,
- new_filename,
- "Summary",
- f"{file_description} of {original_filename}",
- brain_id=(
- self.input.outputs.brain.value
- if (
- self.input.outputs.brain.activated
- and self.input.outputs.brain.value
- )
- else None
- ),
- )
-
- # Reset to start of file before upload
- file_to_upload.file.seek(0)
- if self.input.outputs.brain.activated:
- await upload_file(
- uploadFile=file_to_upload,
- brain_id=self.input.outputs.brain.value,
- current_user=self.current_user,
- chat_id=None,
- )
-
- os.remove(new_filename)
-
- return {"message": f"{file_description} generated successfully"}
diff --git a/backend/api/quivr_api/modules/assistant/ito/summary.py b/backend/api/quivr_api/modules/assistant/ito/summary.py
deleted file mode 100644
index 5ef1ef356..000000000
--- a/backend/api/quivr_api/modules/assistant/ito/summary.py
+++ /dev/null
@@ -1,196 +0,0 @@
-import tempfile
-from typing import List
-
-from fastapi import UploadFile
-from langchain.chains import (
- MapReduceDocumentsChain,
- ReduceDocumentsChain,
- StuffDocumentsChain,
-)
-from langchain.chains.llm import LLMChain
-from langchain_community.chat_models import ChatLiteLLM
-from langchain_community.document_loaders import UnstructuredPDFLoader
-from langchain_core.prompts import PromptTemplate
-from langchain_text_splitters import CharacterTextSplitter
-from quivr_api.logger import get_logger
-from quivr_api.modules.assistant.dto.inputs import InputAssistant
-from quivr_api.modules.assistant.dto.outputs import (
- AssistantOutput,
- InputFile,
- Inputs,
- OutputBrain,
- OutputEmail,
- Outputs,
-)
-from quivr_api.modules.assistant.ito.ito import ITO
-from quivr_api.modules.user.entity.user_identity import UserIdentity
-
-logger = get_logger(__name__)
-
-
-class SummaryAssistant(ITO):
-
- def __init__(
- self,
- input: InputAssistant,
- files: List[UploadFile] = None,
- current_user: UserIdentity = None,
- **kwargs,
- ):
- super().__init__(
- input=input,
- files=files,
- current_user=current_user,
- **kwargs,
- )
-
- def check_input(self):
- if not self.files:
- raise ValueError("No file was uploaded")
- if len(self.files) > 1:
- raise ValueError("Only one file can be uploaded")
- if not self.input.inputs.files:
- raise ValueError("No files key were given in the input")
- if len(self.input.inputs.files) > 1:
- raise ValueError("Only one file can be uploaded")
- if not self.input.inputs.files[0].key == "doc_to_summarize":
- raise ValueError("The key of the file should be doc_to_summarize")
- if not self.input.inputs.files[0].value:
- raise ValueError("No file was uploaded")
- # Check if name of file is same as the key
- if not self.input.inputs.files[0].value == self.files[0].filename:
- raise ValueError(
- "The key of the file should be the same as the name of the file"
- )
- if not (
- self.input.outputs.brain.activated or self.input.outputs.email.activated
- ):
- raise ValueError("No output was selected")
- return True
-
- async def process_assistant(self):
-
- try:
- self.increase_usage_user()
- except Exception as e:
- logger.error(f"Error increasing usage: {e}")
- return {"error": str(e)}
-
- # Create a temporary file with the uploaded file as a temporary file and then pass it to the loader
- tmp_file = tempfile.NamedTemporaryFile(delete=False)
-
- # Write the file to the temporary file
- tmp_file.write(self.files[0].file.read())
-
- # Now pass the path of the temporary file to the loader
-
- loader = UnstructuredPDFLoader(tmp_file.name)
-
- tmp_file.close()
-
- data = loader.load()
-
- llm = ChatLiteLLM(model="gpt-4o", max_tokens=2000)
-
- map_template = """The following is a document that has been divided into multiple sections:
- {docs}
-
- Please carefully analyze each section and identify the following:
-
- 1. Main Themes: What are the overarching ideas or topics in this section?
- 2. Key Points: What are the most important facts, arguments, or ideas presented in this section?
- 3. Important Information: Are there any crucial details that stand out? This could include data, quotes, specific events, entity, or other relevant information.
- 4. People: Who are the key individuals mentioned in this section? What roles do they play?
- 5. Reasoning: What logic or arguments are used to support the key points?
- 6. Chapters: If the document is divided into chapters, what is the main focus of each chapter?
-
- Remember to consider the language and context of the document. This will help in understanding the nuances and subtleties of the text."""
- map_prompt = PromptTemplate.from_template(map_template)
- map_chain = LLMChain(llm=llm, prompt=map_prompt)
-
- # Reduce
- reduce_template = """The following is a set of summaries for parts of the document:
- {docs}
- Take these and distill it into a final, consolidated summary of the document. Make sure to include the main themes, key points, and important information such as data, quotes,people and specific events.
- Use markdown such as bold, italics, underlined. For example, **bold**, *italics*, and _underlined_ to highlight key points.
- Please provide the final summary with sections using bold headers.
- Sections should always be Summary and Key Points, but feel free to add more sections as needed.
- Always use bold text for the sections headers.
- Keep the same language as the documents.
- Answer:"""
- reduce_prompt = PromptTemplate.from_template(reduce_template)
-
- # Run chain
- reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
-
- # Takes a list of documents, combines them into a single string, and passes this to an LLMChain
- combine_documents_chain = StuffDocumentsChain(
- llm_chain=reduce_chain, document_variable_name="docs"
- )
-
- # Combines and iteratively reduces the mapped documents
- reduce_documents_chain = ReduceDocumentsChain(
- # This is final chain that is called.
- combine_documents_chain=combine_documents_chain,
- # If documents exceed context for `StuffDocumentsChain`
- collapse_documents_chain=combine_documents_chain,
- # The maximum number of tokens to group documents into.
- token_max=4000,
- )
-
- # Combining documents by mapping a chain over them, then combining results
- map_reduce_chain = MapReduceDocumentsChain(
- # Map chain
- llm_chain=map_chain,
- # Reduce chain
- reduce_documents_chain=reduce_documents_chain,
- # The variable name in the llm_chain to put the documents in
- document_variable_name="docs",
- # Return the results of the map steps in the output
- return_intermediate_steps=False,
- )
-
- text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
- chunk_size=1000, chunk_overlap=100
- )
- split_docs = text_splitter.split_documents(data)
-
- content = map_reduce_chain.run(split_docs)
-
- return await self.create_and_upload_processed_file(
- content, self.files[0].filename, "Summary"
- )
-
-
-def summary_inputs():
- output = AssistantOutput(
- name="Summary",
- description="Summarize a set of documents",
- tags=["new"],
- input_description="One document to summarize",
- output_description="A summary of the document with key points and main themes",
- icon_url="https://quivr-cms.s3.eu-west-3.amazonaws.com/report_94bea8b918.png",
- inputs=Inputs(
- files=[
- InputFile(
- key="doc_to_summarize",
- allowed_extensions=["pdf"],
- required=True,
- description="The document to summarize",
- )
- ]
- ),
- outputs=Outputs(
- brain=OutputBrain(
- required=True,
- description="The brain to which upload the document",
- type="uuid",
- ),
- email=OutputEmail(
- required=True,
- description="Send the document by email",
- type="str",
- ),
- ),
- )
- return output
diff --git a/backend/api/quivr_api/modules/assistant/ito/utils/logo.png b/backend/api/quivr_api/modules/assistant/ito/utils/logo.png
deleted file mode 100644
index 5e5672d8c..000000000
Binary files a/backend/api/quivr_api/modules/assistant/ito/utils/logo.png and /dev/null differ
diff --git a/backend/api/quivr_api/modules/assistant/repository/__init__.py b/backend/api/quivr_api/modules/assistant/repository/__init__.py
index 560e87c0e..e69de29bb 100644
--- a/backend/api/quivr_api/modules/assistant/repository/__init__.py
+++ b/backend/api/quivr_api/modules/assistant/repository/__init__.py
@@ -1 +0,0 @@
-from .assistant_interface import AssistantInterface
diff --git a/backend/api/quivr_api/modules/assistant/repository/assistant_interface.py b/backend/api/quivr_api/modules/assistant/repository/assistant_interface.py
deleted file mode 100644
index 1740458bf..000000000
--- a/backend/api/quivr_api/modules/assistant/repository/assistant_interface.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import List
-
-from quivr_api.modules.assistant.entity.assistant import AssistantEntity
-
-
-class AssistantInterface(ABC):
-
- @abstractmethod
- def get_all_assistants(self) -> List[AssistantEntity]:
- """
- Get all the knowledge in a brain
- Args:
- brain_id (UUID): The id of the brain
- """
- pass
diff --git a/backend/api/quivr_api/modules/assistant/ito/__init__.py b/backend/api/quivr_api/modules/assistant/repository/interfaces/__init__.py
similarity index 100%
rename from backend/api/quivr_api/modules/assistant/ito/__init__.py
rename to backend/api/quivr_api/modules/assistant/repository/interfaces/__init__.py
diff --git a/backend/api/quivr_api/modules/assistant/repository/interfaces/task_interface.py b/backend/api/quivr_api/modules/assistant/repository/interfaces/task_interface.py
new file mode 100644
index 000000000..74f2046c6
--- /dev/null
+++ b/backend/api/quivr_api/modules/assistant/repository/interfaces/task_interface.py
@@ -0,0 +1,32 @@
+from abc import ABC, abstractmethod
+from typing import List
+from uuid import UUID
+
+from quivr_api.modules.assistant.dto.inputs import CreateTask
+from quivr_api.modules.assistant.entity.task_entity import Task
+
+
+class TasksInterface(ABC):
+    """Abstract contract for objects that store and retrieve assistant tasks.
+
+    Every method is abstract, so a subclass has to override all of them
+    before it can be instantiated.
+
+    NOTE(review): ``TasksRepository`` declares these operations as coroutines
+    and its ``create_task`` also receives a ``user_id``; confirm whether this
+    interface is meant to mirror those signatures.
+    NOTE(review): ``task_id`` is annotated ``UUID`` on some methods and ``int``
+    on others; confirm the real type of ``Task.id``.
+    """
+
+    @abstractmethod
+    def create_task(self, task: CreateTask) -> Task:
+        """Create a task from ``task`` and return the resulting ``Task``."""
+
+    @abstractmethod
+    def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
+        """Return the task identified by ``task_id`` for ``user_id``."""
+
+    @abstractmethod
+    def delete_task(self, task_id: UUID, user_id: UUID) -> None:
+        """Delete the task identified by ``task_id`` for ``user_id``."""
+
+    @abstractmethod
+    def get_tasks_by_user_id(self, user_id: UUID) -> List[Task]:
+        """Return the list of tasks associated with ``user_id``."""
+
+    @abstractmethod
+    def update_task(self, task_id: int, task: dict) -> None:
+        """Update the task identified by ``task_id`` using the ``task`` dict."""
+
+    @abstractmethod
+    def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
+        """Return a download link (as a string) for the given task."""
diff --git a/backend/api/quivr_api/modules/assistant/repository/tasks.py b/backend/api/quivr_api/modules/assistant/repository/tasks.py
new file mode 100644
index 000000000..7977a2f56
--- /dev/null
+++ b/backend/api/quivr_api/modules/assistant/repository/tasks.py
@@ -0,0 +1,82 @@
+from typing import Sequence
+from uuid import UUID
+
+from sqlalchemy import exc
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import select
+
+from quivr_api.modules.assistant.dto.inputs import CreateTask
+from quivr_api.modules.assistant.entity.task_entity import Task
+from quivr_api.modules.dependencies import BaseRepository
+from quivr_api.modules.upload.service.generate_file_signed_url import (
+ generate_file_signed_url,
+)
+
+
+class TasksRepository(BaseRepository):
+    """Data-access layer for ``Task`` rows, operating on an async session."""
+
+    def __init__(self, session: AsyncSession):
+        super().__init__(session)
+
+    async def create_task(self, task: CreateTask, user_id: UUID) -> Task:
+        """Insert a new task owned by ``user_id`` and return the stored row.
+
+        Args:
+            task: Source of ``assistant_id``, ``pretty_id`` and ``settings``.
+            user_id: Owner recorded on the new row.
+
+        Returns:
+            The ``Task`` instance after commit and refresh.
+
+        Raises:
+            Exception: If the commit hits an ``IntegrityError``. The session
+                is rolled back first and the original error is chained as the
+                cause so it is not lost.
+        """
+        task_to_create = Task(
+            assistant_id=task.assistant_id,
+            pretty_id=task.pretty_id,
+            user_id=user_id,
+            settings=task.settings,
+        )
+        try:
+            self.session.add(task_to_create)
+            await self.session.commit()
+        except exc.IntegrityError as e:
+            await self.session.rollback()
+            # Same exception type as before, but with a message and the cause attached.
+            raise Exception(
+                "Task could not be created: integrity constraint violated"
+            ) from e
+
+        await self.session.refresh(task_to_create)
+        return task_to_create
+
+    async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
+        """Return the single task matching both ``task_id`` and ``user_id``.
+
+        ``one()`` raises if the query yields zero rows or more than one row.
+        """
+        query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
+        response = await self.session.exec(query)
+        return response.one()
+
+    async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]:
+        """Return every task whose ``user_id`` column equals ``user_id``."""
+        query = select(Task).where(Task.user_id == user_id)
+        response = await self.session.exec(query)
+        return response.all()
+
+    async def delete_task(self, task_id: int, user_id: UUID) -> None:
+        """Delete the task matching ``task_id`` and ``user_id`` and commit.
+
+        ``one()`` raises when no row (or more than one row) matches, so the
+        delete only runs for an existing task.
+        """
+        query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
+        response = await self.session.exec(query)
+        # one() never returns a missing row, so no "not found" fallback is needed here.
+        task = response.one()
+        await self.session.delete(task)
+        await self.session.commit()
+
+    async def update_task(self, task_id: int, task_updates: dict) -> None:
+        """Set each key of ``task_updates`` as an attribute on the task and commit.
+
+        The lookup filters on ``task_id`` only.
+        NOTE(review): unlike the other methods this one is not scoped by
+        ``user_id``; confirm that callers are trusted.
+
+        ``one()`` raises when no row (or more than one row) matches.
+        """
+        query = select(Task).where(Task.id == task_id)
+        response = await self.session.exec(query)
+        task = response.one()
+        for key, value in task_updates.items():
+            setattr(task, key, value)
+        await self.session.commit()
+
+    async def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
+        """Return a signed URL for the task's ``output.pdf``.
+
+        The storage path is ``{assistant_id}/{pretty_id}/output.pdf``.
+
+        Returns:
+            The ``signedURL`` value from ``generate_file_signed_url``, or the
+            literal string ``"error"`` when that call raises or its result
+            has no ``signedURL`` entry.
+        """
+        query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
+        response = await self.session.exec(query)
+        task = response.one()
+
+        path = f"{task.assistant_id}/{task.pretty_id}/output.pdf"
+
+        try:
+            signed_url = generate_file_signed_url(path)
+            if signed_url and "signedURL" in signed_url:
+                return signed_url["signedURL"]
+        except Exception:
+            # Best effort: callers receive the "error" sentinel instead of an exception.
+            return "error"
+        return "error"
diff --git a/backend/api/quivr_api/modules/assistant/service/assistant.py b/backend/api/quivr_api/modules/assistant/service/assistant.py
deleted file mode 100644
index e4c013d6f..000000000
--- a/backend/api/quivr_api/modules/assistant/service/assistant.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from quivr_api.modules.assistant.entity.assistant import AssistantEntity
-from quivr_api.modules.assistant.repository.assistant_interface import (
- AssistantInterface,
-)
-from quivr_api.modules.dependencies import get_supabase_client
-
-
-class Assistant(AssistantInterface):
- def __init__(self):
- supabase_client = get_supabase_client()
- self.db = supabase_client
-
- def get_all_assistants(self):
- response = self.db.from_("assistants").select("*").execute()
-
- if response.data:
- return response.data
-
- return []
-
- def get_assistant_by_id(self, ingestion_id) -> AssistantEntity:
- response = (
- self.db.from_("assistants")
- .select("*")
- .filter("id", "eq", ingestion_id)
- .execute()
- )
-
- if response.data:
- return AssistantEntity(**response.data[0])
-
- return None
diff --git a/backend/api/quivr_api/modules/assistant/ito/utils/__init__.py b/backend/api/quivr_api/modules/assistant/services/__init__.py
similarity index 100%
rename from backend/api/quivr_api/modules/assistant/ito/utils/__init__.py
rename to backend/api/quivr_api/modules/assistant/services/__init__.py
diff --git a/backend/api/quivr_api/modules/assistant/services/tasks_service.py b/backend/api/quivr_api/modules/assistant/services/tasks_service.py
new file mode 100644
index 000000000..e7df1f3a6
--- /dev/null
+++ b/backend/api/quivr_api/modules/assistant/services/tasks_service.py
@@ -0,0 +1,32 @@
+from typing import Sequence
+from uuid import UUID
+
+from quivr_api.modules.assistant.dto.inputs import CreateTask
+from quivr_api.modules.assistant.entity.task_entity import Task
+from quivr_api.modules.assistant.repository.tasks import TasksRepository
+from quivr_api.modules.dependencies import BaseService
+
+
+class TasksService(BaseService[TasksRepository]):
+    """Service layer that forwards each task operation to ``TasksRepository``."""
+
+    repository_cls = TasksRepository
+
+    def __init__(self, repository: TasksRepository):
+        self.repository = repository
+
+    async def create_task(self, task: CreateTask, user_id: UUID) -> Task:
+        """Create a task for ``user_id`` through the repository and return it."""
+        created = await self.repository.create_task(task, user_id)
+        return created
+
+    async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
+        """Return the repository's result for ``task_id`` and ``user_id``."""
+        found = await self.repository.get_task_by_id(task_id, user_id)
+        return found
+
+    async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]:
+        """Return the tasks the repository reports for ``user_id``."""
+        tasks = await self.repository.get_tasks_by_user_id(user_id)
+        return tasks
+
+    async def delete_task(self, task_id: int, user_id: UUID) -> None:
+        """Ask the repository to delete the task and pass its result through."""
+        outcome = await self.repository.delete_task(task_id, user_id)
+        return outcome
+
+    async def update_task(self, task_id: int, task: dict) -> None:
+        """Ask the repository to apply the ``task`` dict to ``task_id``."""
+        outcome = await self.repository.update_task(task_id, task)
+        return outcome
+
+    async def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
+        """Return the download link string produced by the repository."""
+        link = await self.repository.get_download_link_task(task_id, user_id)
+        return link
diff --git a/backend/api/quivr_api/modules/brain/controller/__init__.py b/backend/api/quivr_api/modules/brain/controller/__init__.py
index 7e54fbb96..98f5cd9dc 100644
--- a/backend/api/quivr_api/modules/brain/controller/__init__.py
+++ b/backend/api/quivr_api/modules/brain/controller/__init__.py
@@ -1 +1,5 @@
from .brain_routes import brain_router
+
+__all__ = [
+ "brain_router",
+]
diff --git a/backend/api/quivr_api/modules/brain/dto/inputs.py b/backend/api/quivr_api/modules/brain/dto/inputs.py
index 632cd9794..bbdbb1801 100644
--- a/backend/api/quivr_api/modules/brain/dto/inputs.py
+++ b/backend/api/quivr_api/modules/brain/dto/inputs.py
@@ -2,6 +2,7 @@ from typing import Optional
from uuid import UUID
from pydantic import BaseModel
+
from quivr_api.logger import get_logger
from quivr_api.modules.brain.entity.brain_entity import BrainType
from quivr_api.modules.brain.entity.integration_brain import IntegrationType
diff --git a/backend/api/quivr_api/modules/brain/integrations/Big/Brain.py b/backend/api/quivr_api/modules/brain/integrations/Big/Brain.py
index 0c7c61297..141f7de7c 100644
--- a/backend/api/quivr_api/modules/brain/integrations/Big/Brain.py
+++ b/backend/api/quivr_api/modules/brain/integrations/Big/Brain.py
@@ -11,6 +11,7 @@ from langchain_core.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain_core.prompts.prompt import PromptTemplate
+
from quivr_api.logger import get_logger
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.chat.dto.chats import ChatQuestion
diff --git a/backend/api/quivr_api/modules/brain/integrations/Claude/Brain.py b/backend/api/quivr_api/modules/brain/integrations/Claude/Brain.py
index 14cf00236..25732779e 100644
--- a/backend/api/quivr_api/modules/brain/integrations/Claude/Brain.py
+++ b/backend/api/quivr_api/modules/brain/integrations/Claude/Brain.py
@@ -4,6 +4,7 @@ from uuid import UUID
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.chat.dto.chats import ChatQuestion
diff --git a/backend/api/quivr_api/modules/brain/integrations/GPT4/Brain.py b/backend/api/quivr_api/modules/brain/integrations/GPT4/Brain.py
index f643de065..0083b48cc 100644
--- a/backend/api/quivr_api/modules/brain/integrations/GPT4/Brain.py
+++ b/backend/api/quivr_api/modules/brain/integrations/GPT4/Brain.py
@@ -10,6 +10,7 @@ from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation
+
from quivr_api.logger import get_logger
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.chat.dto.chats import ChatQuestion
diff --git a/backend/api/quivr_api/modules/brain/integrations/Proxy/Brain.py b/backend/api/quivr_api/modules/brain/integrations/Proxy/Brain.py
index 2816d6e57..4d5baa142 100644
--- a/backend/api/quivr_api/modules/brain/integrations/Proxy/Brain.py
+++ b/backend/api/quivr_api/modules/brain/integrations/Proxy/Brain.py
@@ -4,6 +4,7 @@ from uuid import UUID
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
from quivr_api.logger import get_logger
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.chat.dto.chats import ChatQuestion
diff --git a/backend/api/quivr_api/modules/brain/integrations/SQL/Brain.py b/backend/api/quivr_api/modules/brain/integrations/SQL/Brain.py
index 37509f44a..12a01d4fb 100644
--- a/backend/api/quivr_api/modules/brain/integrations/SQL/Brain.py
+++ b/backend/api/quivr_api/modules/brain/integrations/SQL/Brain.py
@@ -7,6 +7,7 @@ from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
+
from quivr_api.modules.brain.integrations.SQL.SQL_connector import SQLConnector
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.brain.repository.integration_brains import IntegrationBrain
@@ -85,7 +86,6 @@ class SQLBrain(KnowledgeBrainQA, IntegrationBrain):
async def generate_stream(
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> AsyncIterable:
-
conversational_qa_chain = self.get_chain()
transformed_history, streamed_chat_history = (
self.initialize_streamed_chat_history(chat_id, question)
diff --git a/backend/api/quivr_api/modules/brain/integrations/Self/Brain.py b/backend/api/quivr_api/modules/brain/integrations/Self/Brain.py
index 867447653..6a992f687 100644
--- a/backend/api/quivr_api/modules/brain/integrations/Self/Brain.py
+++ b/backend/api/quivr_api/modules/brain/integrations/Self/Brain.py
@@ -12,13 +12,14 @@ from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
from langchain_core.pydantic_v1 import Field as FieldV1
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
+from typing_extensions import TypedDict
+
from quivr_api.logger import get_logger
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.chat.dto.chats import ChatQuestion
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
from quivr_api.modules.chat.service.chat_service import ChatService
from quivr_api.modules.dependencies import get_service
-from typing_extensions import TypedDict
# Post-processing
@@ -210,13 +211,11 @@ class SelfBrain(KnowledgeBrainQA):
return question_rewriter
def get_chain(self):
-
graph = self.create_graph()
return graph
def create_graph(self):
-
workflow = StateGraph(GraphState)
# Define the nodes
diff --git a/backend/api/quivr_api/modules/brain/repository/interfaces/__init__.py b/backend/api/quivr_api/modules/brain/repository/interfaces/__init__.py
index aab7d31bb..15163d6c6 100644
--- a/backend/api/quivr_api/modules/brain/repository/interfaces/__init__.py
+++ b/backend/api/quivr_api/modules/brain/repository/interfaces/__init__.py
@@ -2,5 +2,7 @@
from .brains_interface import BrainsInterface
from .brains_users_interface import BrainsUsersInterface
from .brains_vectors_interface import BrainsVectorsInterface
-from .integration_brains_interface import (IntegrationBrainInterface,
- IntegrationDescriptionInterface)
+from .integration_brains_interface import (
+ IntegrationBrainInterface,
+ IntegrationDescriptionInterface,
+)
diff --git a/backend/api/quivr_api/modules/brain/repository/interfaces/integration_brains_interface.py b/backend/api/quivr_api/modules/brain/repository/interfaces/integration_brains_interface.py
index 60e187488..8f6867514 100644
--- a/backend/api/quivr_api/modules/brain/repository/interfaces/integration_brains_interface.py
+++ b/backend/api/quivr_api/modules/brain/repository/interfaces/integration_brains_interface.py
@@ -38,7 +38,6 @@ class IntegrationBrainInterface(ABC):
class IntegrationDescriptionInterface(ABC):
-
@abstractmethod
def get_integration_description(
self, integration_id: UUID
diff --git a/backend/api/quivr_api/modules/brain/service/brain_authorization_service.py b/backend/api/quivr_api/modules/brain/service/brain_authorization_service.py
index e9a69290e..9583c1239 100644
--- a/backend/api/quivr_api/modules/brain/service/brain_authorization_service.py
+++ b/backend/api/quivr_api/modules/brain/service/brain_authorization_service.py
@@ -2,6 +2,7 @@ from typing import List, Optional, Union
from uuid import UUID
from fastapi import Depends, HTTPException, status
+
from quivr_api.middlewares.auth.auth_bearer import get_current_user
from quivr_api.modules.brain.entity.brain_entity import RoleEnum
from quivr_api.modules.brain.service.brain_service import BrainService
@@ -13,7 +14,7 @@ brain_service = BrainService()
def has_brain_authorization(
- required_roles: Optional[Union[RoleEnum, List[RoleEnum]]] = RoleEnum.Owner
+ required_roles: Optional[Union[RoleEnum, List[RoleEnum]]] = RoleEnum.Owner,
):
"""
Decorator to check if the user has the required role(s) for the brain
diff --git a/backend/api/quivr_api/modules/brain/service/brain_user_service.py b/backend/api/quivr_api/modules/brain/service/brain_user_service.py
index b1bf15038..031cfb8a3 100644
--- a/backend/api/quivr_api/modules/brain/service/brain_user_service.py
+++ b/backend/api/quivr_api/modules/brain/service/brain_user_service.py
@@ -2,6 +2,7 @@ from typing import List
from uuid import UUID
from fastapi import HTTPException
+
from quivr_api.logger import get_logger
from quivr_api.modules.brain.entity.brain_entity import (
BrainEntity,
diff --git a/backend/api/quivr_api/modules/brain/service/utils/format_chat_history.py b/backend/api/quivr_api/modules/brain/service/utils/format_chat_history.py
index a66cfab5e..0b3d3c795 100644
--- a/backend/api/quivr_api/modules/brain/service/utils/format_chat_history.py
+++ b/backend/api/quivr_api/modules/brain/service/utils/format_chat_history.py
@@ -1,6 +1,7 @@
from typing import List, Tuple
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
+
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
diff --git a/backend/api/quivr_api/modules/brain/service/utils/get_prompt_to_use_id.py b/backend/api/quivr_api/modules/brain/service/utils/get_prompt_to_use_id.py
index 51614b53e..e5d118bf0 100644
--- a/backend/api/quivr_api/modules/brain/service/utils/get_prompt_to_use_id.py
+++ b/backend/api/quivr_api/modules/brain/service/utils/get_prompt_to_use_id.py
@@ -15,5 +15,7 @@ def get_prompt_to_use_id(
return (
prompt_id
if prompt_id
- else brain_service.get_brain_prompt_id(brain_id) if brain_id else None
+ else brain_service.get_brain_prompt_id(brain_id)
+ if brain_id
+ else None
)
diff --git a/backend/api/quivr_api/modules/brain/service/utils/validate_brain.py b/backend/api/quivr_api/modules/brain/service/utils/validate_brain.py
index 43ec8e025..e69de29bb 100644
--- a/backend/api/quivr_api/modules/brain/service/utils/validate_brain.py
+++ b/backend/api/quivr_api/modules/brain/service/utils/validate_brain.py
@@ -1,2 +0,0 @@
-from fastapi import HTTPException
-from quivr_api.modules.brain.dto.inputs import CreateBrainProperties
diff --git a/backend/api/quivr_api/modules/chat/dto/chats.py b/backend/api/quivr_api/modules/chat/dto/chats.py
index e900602d1..a04a6dcc7 100644
--- a/backend/api/quivr_api/modules/chat/dto/chats.py
+++ b/backend/api/quivr_api/modules/chat/dto/chats.py
@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
from uuid import UUID
from pydantic import BaseModel
+
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
from quivr_api.modules.notification.entity.notification import Notification
diff --git a/backend/api/quivr_api/modules/chat/entity/chat.py b/backend/api/quivr_api/modules/chat/entity/chat.py
index f989f37f3..965b38da8 100644
--- a/backend/api/quivr_api/modules/chat/entity/chat.py
+++ b/backend/api/quivr_api/modules/chat/entity/chat.py
@@ -2,12 +2,12 @@ from datetime import datetime
from typing import List
from uuid import UUID
+from sqlalchemy.ext.asyncio import AsyncAttrs
+from sqlmodel import JSON, TIMESTAMP, Column, Field, Relationship, SQLModel, text
+from sqlmodel import UUID as PGUUID
+
from quivr_api.modules.brain.entity.brain_entity import Brain
from quivr_api.modules.user.entity.user_identity import User
-from sqlalchemy.ext.asyncio import AsyncAttrs
-from sqlmodel import JSON, TIMESTAMP
-from sqlmodel import UUID as PGUUID
-from sqlmodel import Column, Field, Relationship, SQLModel, text
class Chat(SQLModel, table=True):
diff --git a/backend/api/quivr_api/modules/knowledge/controller/__init__.py b/backend/api/quivr_api/modules/knowledge/controller/__init__.py
index 23c692c11..911883cdc 100644
--- a/backend/api/quivr_api/modules/knowledge/controller/__init__.py
+++ b/backend/api/quivr_api/modules/knowledge/controller/__init__.py
@@ -1 +1 @@
-from .knowledge_routes import knowledge_router
\ No newline at end of file
+from .knowledge_routes import knowledge_router
diff --git a/backend/api/quivr_api/modules/knowledge/repository/storage.py b/backend/api/quivr_api/modules/knowledge/repository/storage.py
index 0e58e25d9..e53165e22 100644
--- a/backend/api/quivr_api/modules/knowledge/repository/storage.py
+++ b/backend/api/quivr_api/modules/knowledge/repository/storage.py
@@ -85,3 +85,5 @@ class SupabaseS3Storage(StorageInterface):
return response
except Exception as e:
logger.error(e)
+ raise e
+
diff --git a/backend/api/quivr_api/modules/models/controller/model_routes.py b/backend/api/quivr_api/modules/models/controller/model_routes.py
index a5370c90f..75b649a33 100644
--- a/backend/api/quivr_api/modules/models/controller/model_routes.py
+++ b/backend/api/quivr_api/modules/models/controller/model_routes.py
@@ -1,6 +1,7 @@
from typing import Annotated, List
from fastapi import APIRouter, Depends
+
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.dependencies import get_service
diff --git a/backend/api/quivr_api/modules/notification/dto/__init__.py b/backend/api/quivr_api/modules/notification/dto/__init__.py
index 726ac989c..2d81927d4 100644
--- a/backend/api/quivr_api/modules/notification/dto/__init__.py
+++ b/backend/api/quivr_api/modules/notification/dto/__init__.py
@@ -1 +1 @@
-from .inputs import NotificationUpdatableProperties
\ No newline at end of file
+from .inputs import NotificationUpdatableProperties
diff --git a/backend/api/quivr_api/modules/prompt/controller/prompt_routes.py b/backend/api/quivr_api/modules/prompt/controller/prompt_routes.py
index 3aa5b6c75..82e25a4bf 100644
--- a/backend/api/quivr_api/modules/prompt/controller/prompt_routes.py
+++ b/backend/api/quivr_api/modules/prompt/controller/prompt_routes.py
@@ -1,6 +1,7 @@
from uuid import UUID
from fastapi import APIRouter, Depends
+
from quivr_api.middlewares.auth import AuthBearer
from quivr_api.modules.prompt.entity.prompt import (
CreatePromptProperties,
diff --git a/backend/api/quivr_api/modules/prompt/entity/__init__.py b/backend/api/quivr_api/modules/prompt/entity/__init__.py
index f3437a0ca..324aeee09 100644
--- a/backend/api/quivr_api/modules/prompt/entity/__init__.py
+++ b/backend/api/quivr_api/modules/prompt/entity/__init__.py
@@ -1 +1,7 @@
-from .prompt import Prompt, PromptStatusEnum, CreatePromptProperties, PromptUpdatableProperties, DeletePromptResponse
\ No newline at end of file
+from .prompt import (
+ CreatePromptProperties,
+ DeletePromptResponse,
+ Prompt,
+ PromptStatusEnum,
+ PromptUpdatableProperties,
+)
diff --git a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py
index c905fb5ba..2f40c140c 100644
--- a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py
+++ b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py
@@ -4,6 +4,7 @@ import requests
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
from msal import ConfidentialClientApplication
+
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
diff --git a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py
index ecc88a5b3..84599965c 100644
--- a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py
+++ b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py
@@ -3,6 +3,7 @@ import os
import requests
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
+
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
diff --git a/backend/api/quivr_api/modules/sync/controller/successfull_connection.py b/backend/api/quivr_api/modules/sync/controller/successfull_connection.py
index ffdb877e8..0e9f00852 100644
--- a/backend/api/quivr_api/modules/sync/controller/successfull_connection.py
+++ b/backend/api/quivr_api/modules/sync/controller/successfull_connection.py
@@ -50,4 +50,4 @@ successfullConnectionPage = """