mirror of
https://github.com/StanGirard/quivr.git
synced 2024-10-04 00:33:03 +03:00
feat(assistants): mock api (#3195)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
This commit is contained in:
parent
4390d318a2
commit
282fa0e3f8
1
.gitignore
vendored
1
.gitignore
vendored
@ -102,3 +102,4 @@ backend/core/examples/chatbot/.chainlit/translations/en-US.json
|
||||
# Tox
|
||||
.tox
|
||||
Pipfile
|
||||
*.pkl
|
||||
|
@ -19,7 +19,6 @@ dependencies = [
|
||||
"pydantic-settings>=2.4.0",
|
||||
"python-dotenv>=1.0.1",
|
||||
"unidecode>=1.3.8",
|
||||
"fpdf>=1.7.2",
|
||||
"colorlog>=6.8.2",
|
||||
"posthog>=3.5.0",
|
||||
"pyinstrument>=4.7.2",
|
||||
|
@ -4,7 +4,7 @@ from logging.handlers import RotatingFileHandler
|
||||
|
||||
from colorlog import (
|
||||
ColoredFormatter,
|
||||
) # You need to install this package: pip install colorlog
|
||||
)
|
||||
|
||||
|
||||
def get_logger(logger_name, log_file="application.log"):
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from jose import jwt
|
||||
from jose.exceptions import JWTError
|
||||
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
SECRET_KEY = os.environ.get("JWT_SECRET_KEY")
|
||||
|
@ -1,6 +1,7 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
@ -1,6 +1,7 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
|
||||
from quivr_api.modules.analytics.entity.analytics import Range
|
||||
from quivr_api.modules.analytics.service.analytics_service import AnalyticsService
|
||||
|
@ -1,16 +1,20 @@
|
||||
from datetime import date
|
||||
from enum import IntEnum
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import date
|
||||
|
||||
|
||||
class Range(IntEnum):
|
||||
WEEK = 7
|
||||
MONTH = 30
|
||||
QUARTER = 90
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
date: date
|
||||
usage_count: int
|
||||
|
||||
|
||||
class BrainsUsages(BaseModel):
|
||||
usages: List[Usage]
|
||||
usages: List[Usage]
|
||||
|
@ -11,5 +11,4 @@ class AnalyticsService:
|
||||
self.repository = Analytics()
|
||||
|
||||
def get_brains_usages(self, user_id, graph_range, brain_id=None):
|
||||
|
||||
return self.repository.get_brains_usages(user_id, graph_range, brain_id)
|
||||
|
@ -3,6 +3,7 @@ from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||
from quivr_api.modules.api_key.dto.outputs import ApiKeyInfo
|
||||
|
@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.api_key.repository.api_key_interface import ApiKeysInterface
|
||||
from quivr_api.modules.api_key.repository.api_keys import ApiKeys
|
||||
|
@ -1 +1,6 @@
|
||||
# noqa:
|
||||
from .assistant_routes import assistant_router
|
||||
|
||||
__all__ = [
|
||||
"assistant_router",
|
||||
]
|
||||
|
@ -1,63 +1,176 @@
|
||||
from typing import List
|
||||
import io
|
||||
from typing import Annotated, List
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
|
||||
|
||||
from quivr_api.celery_config import celery
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||
from quivr_api.modules.assistant.dto.inputs import InputAssistant
|
||||
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
|
||||
from quivr_api.modules.assistant.controller.assistants_definition import (
|
||||
assistants,
|
||||
validate_assistant_input,
|
||||
)
|
||||
from quivr_api.modules.assistant.dto.inputs import CreateTask, InputAssistant
|
||||
from quivr_api.modules.assistant.dto.outputs import AssistantOutput
|
||||
from quivr_api.modules.assistant.ito.difference import DifferenceAssistant
|
||||
from quivr_api.modules.assistant.ito.summary import SummaryAssistant, summary_inputs
|
||||
from quivr_api.modules.assistant.service.assistant import Assistant
|
||||
from quivr_api.modules.assistant.entity.assistant_entity import (
|
||||
AssistantSettings,
|
||||
)
|
||||
from quivr_api.modules.assistant.services.tasks_service import TasksService
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
from quivr_api.modules.upload.service.upload_file import (
|
||||
upload_file_storage,
|
||||
)
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
assistant_router = APIRouter()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
assistant_service = Assistant()
|
||||
|
||||
assistant_router = APIRouter()
|
||||
|
||||
|
||||
TasksServiceDep = Annotated[TasksService, Depends(get_service(TasksService))]
|
||||
UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
|
||||
|
||||
|
||||
@assistant_router.get(
|
||||
"/assistants", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
|
||||
)
|
||||
async def list_assistants(
|
||||
async def get_assistants(
|
||||
request: Request,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
) -> List[AssistantOutput]:
|
||||
"""
|
||||
Retrieve and list all the knowledge in a brain.
|
||||
"""
|
||||
logger.info("Getting assistants")
|
||||
|
||||
summary = summary_inputs()
|
||||
# difference = difference_inputs()
|
||||
# crawler = crawler_inputs()
|
||||
return [summary]
|
||||
return assistants
|
||||
|
||||
|
||||
@assistant_router.get(
|
||||
"/assistants/tasks", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
|
||||
)
|
||||
async def get_tasks(
|
||||
request: Request,
|
||||
current_user: UserIdentityDep,
|
||||
tasks_service: TasksServiceDep,
|
||||
):
|
||||
logger.info("Getting tasks")
|
||||
return await tasks_service.get_tasks_by_user_id(current_user.id)
|
||||
|
||||
|
||||
@assistant_router.post(
|
||||
"/assistant/process",
|
||||
"/assistants/task", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
|
||||
)
|
||||
async def create_task(
|
||||
current_user: UserIdentityDep,
|
||||
tasks_service: TasksServiceDep,
|
||||
request: Request,
|
||||
input: InputAssistant,
|
||||
files: List[UploadFile] = None,
|
||||
):
|
||||
assistant = next(
|
||||
(assistant for assistant in assistants if assistant.id == input.id), None
|
||||
)
|
||||
if assistant is None:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
is_valid, validation_errors = validate_assistant_input(input, assistant)
|
||||
if not is_valid:
|
||||
for error in validation_errors:
|
||||
print(error)
|
||||
raise HTTPException(status_code=400, detail=error)
|
||||
else:
|
||||
print("Assistant input is valid.")
|
||||
notification_uuid = uuid4()
|
||||
|
||||
# Process files dynamically
|
||||
for upload_file in files:
|
||||
file_name_path = f"{input.id}/{notification_uuid}/{upload_file.filename}"
|
||||
buff_reader = io.BufferedReader(upload_file.file) # type: ignore
|
||||
try:
|
||||
await upload_file_storage(buff_reader, file_name_path)
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception in upload_route {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to upload file to storage. {e}"
|
||||
)
|
||||
|
||||
task = CreateTask(
|
||||
assistant_id=input.id,
|
||||
pretty_id=str(notification_uuid),
|
||||
settings=input.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
task_created = await tasks_service.create_task(task, current_user.id)
|
||||
|
||||
celery.send_task(
|
||||
"process_assistant_task",
|
||||
kwargs={
|
||||
"assistant_id": input.id,
|
||||
"notification_uuid": notification_uuid,
|
||||
"task_id": task_created.id,
|
||||
"user_id": str(current_user.id),
|
||||
},
|
||||
)
|
||||
return task_created
|
||||
|
||||
|
||||
@assistant_router.get(
|
||||
"/assistants/task/{task_id}",
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["Assistant"],
|
||||
)
|
||||
async def process_assistant(
|
||||
input: InputAssistant,
|
||||
files: List[UploadFile] = None,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
async def get_task(
|
||||
request: Request,
|
||||
task_id: str,
|
||||
current_user: UserIdentityDep,
|
||||
tasks_service: TasksServiceDep,
|
||||
):
|
||||
if input.name.lower() == "summary":
|
||||
summary_assistant = SummaryAssistant(
|
||||
input=input, files=files, current_user=current_user
|
||||
)
|
||||
try:
|
||||
summary_assistant.check_input()
|
||||
return await summary_assistant.process_assistant()
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
elif input.name.lower() == "difference":
|
||||
difference_assistant = DifferenceAssistant(
|
||||
input=input, files=files, current_user=current_user
|
||||
)
|
||||
try:
|
||||
difference_assistant.check_input()
|
||||
return await difference_assistant.process_assistant()
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return {"message": "Assistant not found"}
|
||||
return await tasks_service.get_task_by_id(task_id, current_user.id) # type: ignore
|
||||
|
||||
|
||||
@assistant_router.delete(
|
||||
"/assistants/task/{task_id}",
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["Assistant"],
|
||||
)
|
||||
async def delete_task(
|
||||
request: Request,
|
||||
task_id: int,
|
||||
current_user: UserIdentityDep,
|
||||
tasks_service: TasksServiceDep,
|
||||
):
|
||||
return await tasks_service.delete_task(task_id, current_user.id)
|
||||
|
||||
|
||||
@assistant_router.get(
|
||||
"/assistants/task/{task_id}/download",
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["Assistant"],
|
||||
)
|
||||
async def get_download_link_task(
|
||||
request: Request,
|
||||
task_id: int,
|
||||
current_user: UserIdentityDep,
|
||||
tasks_service: TasksServiceDep,
|
||||
):
|
||||
return await tasks_service.get_download_link_task(task_id, current_user.id)
|
||||
|
||||
|
||||
@assistant_router.get(
|
||||
"/assistants/{assistant_id}/config",
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["Assistant"],
|
||||
response_model=AssistantSettings,
|
||||
summary="Retrieve assistant configuration",
|
||||
description="Get the settings and file requirements for the specified assistant.",
|
||||
)
|
||||
async def get_assistant_config(
|
||||
assistant_id: int,
|
||||
current_user: UserIdentityDep,
|
||||
):
|
||||
assistant = next(
|
||||
(assistant for assistant in assistants if assistant.id == assistant_id), None
|
||||
)
|
||||
if assistant is None:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
return assistant.settings
|
||||
|
@ -0,0 +1,201 @@
|
||||
from quivr_api.modules.assistant.dto.inputs import InputAssistant
|
||||
from quivr_api.modules.assistant.dto.outputs import (
|
||||
AssistantOutput,
|
||||
InputFile,
|
||||
Inputs,
|
||||
Pricing,
|
||||
)
|
||||
|
||||
|
||||
def validate_assistant_input(
|
||||
assistant_input: InputAssistant, assistant_output: AssistantOutput
|
||||
):
|
||||
errors = []
|
||||
|
||||
# Validate files
|
||||
if assistant_output.inputs.files:
|
||||
required_files = [
|
||||
file for file in assistant_output.inputs.files if file.required
|
||||
]
|
||||
input_files = {
|
||||
file_input.key for file_input in (assistant_input.inputs.files or [])
|
||||
}
|
||||
for req_file in required_files:
|
||||
if req_file.key not in input_files:
|
||||
errors.append(f"Missing required file input: {req_file.key}")
|
||||
|
||||
# Validate URLs
|
||||
if assistant_output.inputs.urls:
|
||||
required_urls = [url for url in assistant_output.inputs.urls if url.required]
|
||||
input_urls = {
|
||||
url_input.key for url_input in (assistant_input.inputs.urls or [])
|
||||
}
|
||||
for req_url in required_urls:
|
||||
if req_url.key not in input_urls:
|
||||
errors.append(f"Missing required URL input: {req_url.key}")
|
||||
|
||||
# Validate texts
|
||||
if assistant_output.inputs.texts:
|
||||
required_texts = [
|
||||
text for text in assistant_output.inputs.texts if text.required
|
||||
]
|
||||
input_texts = {
|
||||
text_input.key for text_input in (assistant_input.inputs.texts or [])
|
||||
}
|
||||
for req_text in required_texts:
|
||||
if req_text.key not in input_texts:
|
||||
errors.append(f"Missing required text input: {req_text.key}")
|
||||
else:
|
||||
# Validate regex if applicable
|
||||
req_text_val = next(
|
||||
(t for t in assistant_output.inputs.texts if t.key == req_text.key),
|
||||
None,
|
||||
)
|
||||
if req_text_val and req_text_val.validation_regex:
|
||||
import re
|
||||
|
||||
input_value = next(
|
||||
(
|
||||
t.value
|
||||
for t in assistant_input.inputs.texts
|
||||
if t.key == req_text.key
|
||||
),
|
||||
"",
|
||||
)
|
||||
if not re.match(req_text_val.validation_regex, input_value):
|
||||
errors.append(
|
||||
f"Text input '{req_text.key}' does not match the required format."
|
||||
)
|
||||
|
||||
# Validate booleans
|
||||
if assistant_output.inputs.booleans:
|
||||
required_booleans = [b for b in assistant_output.inputs.booleans if b.required]
|
||||
input_booleans = {
|
||||
b_input.key for b_input in (assistant_input.inputs.booleans or [])
|
||||
}
|
||||
for req_bool in required_booleans:
|
||||
if req_bool.key not in input_booleans:
|
||||
errors.append(f"Missing required boolean input: {req_bool.key}")
|
||||
|
||||
# Validate numbers
|
||||
if assistant_output.inputs.numbers:
|
||||
required_numbers = [n for n in assistant_output.inputs.numbers if n.required]
|
||||
input_numbers = {
|
||||
n_input.key for n_input in (assistant_input.inputs.numbers or [])
|
||||
}
|
||||
for req_number in required_numbers:
|
||||
if req_number.key not in input_numbers:
|
||||
errors.append(f"Missing required number input: {req_number.key}")
|
||||
else:
|
||||
# Validate min and max
|
||||
input_value = next(
|
||||
(
|
||||
n.value
|
||||
for n in assistant_input.inputs.numbers
|
||||
if n.key == req_number.key
|
||||
),
|
||||
None,
|
||||
)
|
||||
if req_number.min is not None and input_value < req_number.min:
|
||||
errors.append(
|
||||
f"Number input '{req_number.key}' is below minimum value."
|
||||
)
|
||||
if req_number.max is not None and input_value > req_number.max:
|
||||
errors.append(
|
||||
f"Number input '{req_number.key}' exceeds maximum value."
|
||||
)
|
||||
|
||||
# Validate select_texts
|
||||
if assistant_output.inputs.select_texts:
|
||||
required_select_texts = [
|
||||
st for st in assistant_output.inputs.select_texts if st.required
|
||||
]
|
||||
input_select_texts = {
|
||||
st_input.key for st_input in (assistant_input.inputs.select_texts or [])
|
||||
}
|
||||
for req_select in required_select_texts:
|
||||
if req_select.key not in input_select_texts:
|
||||
errors.append(f"Missing required select text input: {req_select.key}")
|
||||
else:
|
||||
input_value = next(
|
||||
(
|
||||
st.value
|
||||
for st in assistant_input.inputs.select_texts
|
||||
if st.key == req_select.key
|
||||
),
|
||||
None,
|
||||
)
|
||||
if input_value not in req_select.options:
|
||||
errors.append(f"Invalid option for select text '{req_select.key}'.")
|
||||
|
||||
# Validate select_numbers
|
||||
if assistant_output.inputs.select_numbers:
|
||||
required_select_numbers = [
|
||||
sn for sn in assistant_output.inputs.select_numbers if sn.required
|
||||
]
|
||||
input_select_numbers = {
|
||||
sn_input.key for sn_input in (assistant_input.inputs.select_numbers or [])
|
||||
}
|
||||
for req_select in required_select_numbers:
|
||||
if req_select.key not in input_select_numbers:
|
||||
errors.append(f"Missing required select number input: {req_select.key}")
|
||||
else:
|
||||
input_value = next(
|
||||
(
|
||||
sn.value
|
||||
for sn in assistant_input.inputs.select_numbers
|
||||
if sn.key == req_select.key
|
||||
),
|
||||
None,
|
||||
)
|
||||
if input_value not in req_select.options:
|
||||
errors.append(
|
||||
f"Invalid option for select number '{req_select.key}'."
|
||||
)
|
||||
|
||||
# Validate brain input
|
||||
if assistant_output.inputs.brain and assistant_output.inputs.brain.required:
|
||||
if not assistant_input.inputs.brain or not assistant_input.inputs.brain.value:
|
||||
errors.append("Missing required brain input.")
|
||||
|
||||
if errors:
|
||||
return False, errors
|
||||
else:
|
||||
return True, None
|
||||
|
||||
|
||||
assistant1 = AssistantOutput(
|
||||
id=1,
|
||||
name="Assistant 1",
|
||||
description="Assistant 1 description",
|
||||
pricing=Pricing(),
|
||||
tags=["tag1", "tag2"],
|
||||
input_description="Input description",
|
||||
output_description="Output description",
|
||||
inputs=Inputs(
|
||||
files=[
|
||||
InputFile(key="file_1", description="File description"),
|
||||
InputFile(key="file_2", description="File description"),
|
||||
],
|
||||
),
|
||||
icon_url="https://example.com/icon.png",
|
||||
)
|
||||
|
||||
assistant2 = AssistantOutput(
|
||||
id=2,
|
||||
name="Assistant 2",
|
||||
description="Assistant 2 description",
|
||||
pricing=Pricing(),
|
||||
tags=["tag1", "tag2"],
|
||||
input_description="Input description",
|
||||
output_description="Output description",
|
||||
icon_url="https://example.com/icon.png",
|
||||
inputs=Inputs(
|
||||
files=[
|
||||
InputFile(key="file_1", description="File description"),
|
||||
InputFile(key="file_2", description="File description"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assistants = [assistant1, assistant2]
|
@ -1,16 +1,16 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, model_validator, root_validator
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
|
||||
class EmailInput(BaseModel):
|
||||
activated: bool
|
||||
class CreateTask(BaseModel):
|
||||
pretty_id: str
|
||||
assistant_id: int
|
||||
settings: dict
|
||||
|
||||
|
||||
class BrainInput(BaseModel):
|
||||
activated: Optional[bool] = False
|
||||
value: Optional[UUID] = None
|
||||
|
||||
@root_validator(pre=True)
|
||||
@ -64,19 +64,10 @@ class Inputs(BaseModel):
|
||||
numbers: Optional[List[InputNumber]] = None
|
||||
select_texts: Optional[List[InputSelectText]] = None
|
||||
select_numbers: Optional[List[InputSelectNumber]] = None
|
||||
|
||||
|
||||
class Outputs(BaseModel):
|
||||
email: Optional[EmailInput] = None
|
||||
brain: Optional[BrainInput] = None
|
||||
|
||||
|
||||
class InputAssistant(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
inputs: Inputs
|
||||
outputs: Outputs
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def to_py_dict(cls, data):
|
||||
return json.loads(data)
|
||||
|
@ -3,6 +3,12 @@ from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BrainInput(BaseModel):
|
||||
required: Optional[bool] = True
|
||||
description: str
|
||||
type: str
|
||||
|
||||
|
||||
class InputFile(BaseModel):
|
||||
key: str
|
||||
allowed_extensions: Optional[List[str]] = ["pdf"]
|
||||
@ -63,23 +69,7 @@ class Inputs(BaseModel):
|
||||
numbers: Optional[List[InputNumber]] = None
|
||||
select_texts: Optional[List[InputSelectText]] = None
|
||||
select_numbers: Optional[List[InputSelectNumber]] = None
|
||||
|
||||
|
||||
class OutputEmail(BaseModel):
|
||||
required: Optional[bool] = True
|
||||
description: str
|
||||
type: str
|
||||
|
||||
|
||||
class OutputBrain(BaseModel):
|
||||
required: Optional[bool] = True
|
||||
description: str
|
||||
type: str
|
||||
|
||||
|
||||
class Outputs(BaseModel):
|
||||
email: Optional[OutputEmail] = None
|
||||
brain: Optional[OutputBrain] = None
|
||||
brain: Optional[BrainInput] = None
|
||||
|
||||
|
||||
class Pricing(BaseModel):
|
||||
@ -88,6 +78,7 @@ class Pricing(BaseModel):
|
||||
|
||||
|
||||
class AssistantOutput(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
pricing: Optional[Pricing] = Pricing()
|
||||
@ -95,5 +86,4 @@ class AssistantOutput(BaseModel):
|
||||
input_description: str
|
||||
output_description: str
|
||||
inputs: Inputs
|
||||
outputs: Outputs
|
||||
icon_url: Optional[str] = None
|
||||
|
@ -1 +0,0 @@
|
||||
from .assistant import AssistantEntity
|
@ -1,11 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AssistantEntity(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
brain_id_required: bool
|
||||
file_1_required: bool
|
||||
url_required: bool
|
@ -0,0 +1,33 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AssistantFileRequirement(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
required: bool = True
|
||||
accepted_types: Optional[List[str]] = None # e.g., ['text/csv', 'application/json']
|
||||
|
||||
|
||||
class AssistantInput(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
type: str # e.g., 'boolean', 'uuid', 'string'
|
||||
required: bool = True
|
||||
regex: Optional[str] = None
|
||||
options: Optional[List[Any]] = None # For predefined choices
|
||||
default: Optional[Any] = None
|
||||
|
||||
|
||||
class AssistantSettings(BaseModel):
|
||||
inputs: List[AssistantInput]
|
||||
files: Optional[List[AssistantFileRequirement]] = None
|
||||
|
||||
|
||||
class Assistant(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
settings: AssistantSettings
|
||||
required_files: Optional[List[str]] = None # List of required file names
|
@ -0,0 +1,34 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import JSON, TIMESTAMP, BigInteger, Column, Field, SQLModel, text
|
||||
|
||||
|
||||
class Task(SQLModel, table=True):
|
||||
__tablename__ = "tasks" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
primary_key=True,
|
||||
autoincrement=True,
|
||||
),
|
||||
)
|
||||
assistant_id: int
|
||||
pretty_id: str
|
||||
user_id: UUID
|
||||
status: str = Field(default="pending")
|
||||
creation_time: datetime | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
TIMESTAMP(timezone=False),
|
||||
server_default=text("CURRENT_TIMESTAMP"),
|
||||
),
|
||||
)
|
||||
settings: Dict = Field(default_factory=dict, sa_column=Column(JSON))
|
||||
answer: str | None = Field(default=None)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
@ -1,73 +0,0 @@
|
||||
from bs4 import BeautifulSoup as Soup
|
||||
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.assistant.dto.outputs import (
|
||||
AssistantOutput,
|
||||
Inputs,
|
||||
InputUrl,
|
||||
OutputBrain,
|
||||
OutputEmail,
|
||||
Outputs,
|
||||
)
|
||||
from quivr_api.modules.assistant.ito.ito import ITO
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CrawlerAssistant(ITO):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def process_assistant(self):
|
||||
|
||||
url = self.url
|
||||
loader = RecursiveUrlLoader(
|
||||
url=url, max_depth=2, extractor=lambda x: Soup(x, "html.parser").text
|
||||
)
|
||||
docs = loader.load()
|
||||
|
||||
nice_url = url.split("://")[1].replace("/", "_").replace(".", "_")
|
||||
nice_url += ".txt"
|
||||
|
||||
for docs in docs:
|
||||
await self.create_and_upload_processed_file(
|
||||
docs.page_content, nice_url, "Crawler"
|
||||
)
|
||||
|
||||
|
||||
def crawler_inputs():
|
||||
output = AssistantOutput(
|
||||
name="Crawler",
|
||||
description="Crawls a website and extracts the text from the pages",
|
||||
tags=["new"],
|
||||
input_description="One URL to crawl",
|
||||
output_description="Text extracted from the pages",
|
||||
inputs=Inputs(
|
||||
urls=[
|
||||
InputUrl(
|
||||
key="url",
|
||||
required=True,
|
||||
description="The URL to crawl",
|
||||
)
|
||||
],
|
||||
),
|
||||
outputs=Outputs(
|
||||
brain=OutputBrain(
|
||||
required=True,
|
||||
description="The brain to which upload the document",
|
||||
type="uuid",
|
||||
),
|
||||
email=OutputEmail(
|
||||
required=True,
|
||||
description="Send the document by email",
|
||||
type="str",
|
||||
),
|
||||
),
|
||||
)
|
||||
return output
|
@ -1,171 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import List
|
||||
|
||||
from fastapi import UploadFile
|
||||
from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
|
||||
from llama_parse import LlamaParse
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.assistant.dto.inputs import InputAssistant
|
||||
from quivr_api.modules.assistant.dto.outputs import (
|
||||
AssistantOutput,
|
||||
InputFile,
|
||||
Inputs,
|
||||
OutputBrain,
|
||||
OutputEmail,
|
||||
Outputs,
|
||||
)
|
||||
from quivr_api.modules.assistant.ito.ito import ITO
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DifferenceAssistant(ITO):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input: InputAssistant,
|
||||
files: List[UploadFile] = None,
|
||||
current_user: UserIdentity = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
input=input,
|
||||
files=files,
|
||||
current_user=current_user,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def check_input(self):
|
||||
if not self.files:
|
||||
raise ValueError("No file was uploaded")
|
||||
if len(self.files) != 2:
|
||||
raise ValueError("Only two files can be uploaded")
|
||||
if not self.input.inputs.files:
|
||||
raise ValueError("No files key were given in the input")
|
||||
if len(self.input.inputs.files) != 2:
|
||||
raise ValueError("Only two files can be uploaded")
|
||||
if not self.input.inputs.files[0].key == "doc_1":
|
||||
raise ValueError("The key of the first file should be doc_1")
|
||||
if not self.input.inputs.files[1].key == "doc_2":
|
||||
raise ValueError("The key of the second file should be doc_2")
|
||||
if not self.input.inputs.files[0].value:
|
||||
raise ValueError("No file was uploaded")
|
||||
if not (
|
||||
self.input.outputs.brain.activated or self.input.outputs.email.activated
|
||||
):
|
||||
raise ValueError("No output was selected")
|
||||
return True
|
||||
|
||||
async def process_assistant(self):
|
||||
|
||||
document_1 = self.files[0]
|
||||
document_2 = self.files[1]
|
||||
|
||||
# Get the file extensions
|
||||
document_1_ext = os.path.splitext(document_1.filename)[1]
|
||||
document_2_ext = os.path.splitext(document_2.filename)[1]
|
||||
|
||||
# Create temporary files with the same extension as the original files
|
||||
document_1_tmp = tempfile.NamedTemporaryFile(
|
||||
suffix=document_1_ext, delete=False
|
||||
)
|
||||
document_2_tmp = tempfile.NamedTemporaryFile(
|
||||
suffix=document_2_ext, delete=False
|
||||
)
|
||||
|
||||
document_1_tmp.write(document_1.file.read())
|
||||
document_2_tmp.write(document_2.file.read())
|
||||
|
||||
parser = LlamaParse(
|
||||
result_type="markdown" # "markdown" and "text" are available
|
||||
)
|
||||
|
||||
document_1_llama_parsed = parser.load_data(document_1_tmp.name)
|
||||
document_2_llama_parsed = parser.load_data(document_2_tmp.name)
|
||||
|
||||
document_1_tmp.close()
|
||||
document_2_tmp.close()
|
||||
|
||||
document_1_to_langchain = document_1_llama_parsed[0].to_langchain_format()
|
||||
document_2_to_langchain = document_2_llama_parsed[0].to_langchain_format()
|
||||
|
||||
llm = ChatLiteLLM(model="gpt-4o")
|
||||
|
||||
human_prompt = """Given the following two documents, find the difference between them:
|
||||
|
||||
Document 1:
|
||||
{document_1}
|
||||
Document 2:
|
||||
{document_2}
|
||||
Difference:
|
||||
"""
|
||||
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(human_prompt)
|
||||
|
||||
system_message_template = """
|
||||
You are an expert in finding the difference between two documents. You look deeply into what makes the two documents different and provide a detailed analysis if needed of the differences between the two documents.
|
||||
If no differences are found, simply say that there are no differences.
|
||||
"""
|
||||
|
||||
ANSWER_PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||
HumanMessagePromptTemplate.from_template(human_prompt),
|
||||
]
|
||||
)
|
||||
|
||||
final_inputs = {
|
||||
"document_1": document_1_to_langchain.page_content,
|
||||
"document_2": document_2_to_langchain.page_content,
|
||||
}
|
||||
|
||||
output_parser = StrOutputParser()
|
||||
|
||||
chain = ANSWER_PROMPT | llm | output_parser
|
||||
result = chain.invoke(final_inputs)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def difference_inputs():
|
||||
output = AssistantOutput(
|
||||
name="difference",
|
||||
description="Finds the difference between two sets of documents",
|
||||
tags=["new"],
|
||||
input_description="Two documents to compare",
|
||||
output_description="The difference between the two documents",
|
||||
icon_url="https://quivr-cms.s3.eu-west-3.amazonaws.com/report_94bea8b918.png",
|
||||
inputs=Inputs(
|
||||
files=[
|
||||
InputFile(
|
||||
key="doc_1",
|
||||
allowed_extensions=["pdf"],
|
||||
required=True,
|
||||
description="The first document to compare",
|
||||
),
|
||||
InputFile(
|
||||
key="doc_2",
|
||||
allowed_extensions=["pdf"],
|
||||
required=True,
|
||||
description="The second document to compare",
|
||||
),
|
||||
]
|
||||
),
|
||||
outputs=Outputs(
|
||||
brain=OutputBrain(
|
||||
required=True,
|
||||
description="The brain to which upload the document",
|
||||
type="uuid",
|
||||
),
|
||||
email=OutputEmail(
|
||||
required=True,
|
||||
description="Send the document by email",
|
||||
type="str",
|
||||
),
|
||||
),
|
||||
)
|
||||
return output
|
@ -1,196 +0,0 @@
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
from abc import abstractmethod
|
||||
from io import BytesIO
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from unidecode import unidecode
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.models.settings import SendEmailSettings
|
||||
from quivr_api.modules.assistant.dto.inputs import InputAssistant
|
||||
from quivr_api.modules.assistant.ito.utils.pdf_generator import PDFGenerator, PDFModel
|
||||
from quivr_api.modules.chat.controller.chat.utils import update_user_usage
|
||||
from quivr_api.modules.upload.controller.upload_routes import upload_file
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
from quivr_api.modules.user.service.user_usage import UserUsage
|
||||
from quivr_api.utils.send_email import send_email
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ITO(BaseModel):
|
||||
input: InputAssistant
|
||||
files: List[UploadFile]
|
||||
current_user: UserIdentity
|
||||
user_usage: Optional[UserUsage] = None
|
||||
user_settings: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input: InputAssistant,
|
||||
files: List[UploadFile] = None,
|
||||
current_user: UserIdentity = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
input=input,
|
||||
files=files,
|
||||
current_user=current_user,
|
||||
**kwargs,
|
||||
)
|
||||
self.user_usage = UserUsage(
|
||||
id=current_user.id,
|
||||
email=current_user.email,
|
||||
)
|
||||
self.user_settings = self.user_usage.get_user_settings()
|
||||
self.increase_usage_user()
|
||||
|
||||
def increase_usage_user(self):
|
||||
# Raises an error if the user has consumed all of of his credits
|
||||
|
||||
update_user_usage(
|
||||
usage=self.user_usage,
|
||||
user_settings=self.user_settings,
|
||||
cost=self.calculate_pricing(),
|
||||
)
|
||||
|
||||
def calculate_pricing(self):
|
||||
return 20
|
||||
|
||||
def generate_pdf(self, filename: str, title: str, content: str):
|
||||
pdf_model = PDFModel(title=title, content=content)
|
||||
pdf = PDFGenerator(pdf_model)
|
||||
pdf.print_pdf()
|
||||
pdf.output(filename, "F")
|
||||
|
||||
@abstractmethod
|
||||
async def process_assistant(self):
|
||||
pass
|
||||
|
||||
async def send_output_by_email(
|
||||
self,
|
||||
file: UploadFile,
|
||||
filename: str,
|
||||
task_name: str,
|
||||
custom_message: str,
|
||||
brain_id: str = None,
|
||||
):
|
||||
settings = SendEmailSettings()
|
||||
file = await self.uploadfile_to_file(file)
|
||||
domain_quivr = os.getenv("QUIVR_DOMAIN", "https://chat.quivr.app/")
|
||||
|
||||
with open(file.name, "rb") as f:
|
||||
mail_from = settings.resend_contact_sales_from
|
||||
mail_to = self.current_user.email
|
||||
body = f"""
|
||||
<div style="text-align: center;">
|
||||
<img src="https://quivr-cms.s3.eu-west-3.amazonaws.com/logo_quivr_white_7e3c72620f.png" alt="Quivr Logo" style="width: 100px; height: 100px; border-radius: 50%; margin: 0 auto; display: block;">
|
||||
|
||||
<p>Quivr's ingestion process has been completed. The processed file is attached.</p>
|
||||
|
||||
<p><strong>Task:</strong> {task_name}</p>
|
||||
|
||||
<p><strong>Output:</strong> {custom_message}</p>
|
||||
<br />
|
||||
|
||||
|
||||
</div>
|
||||
"""
|
||||
if brain_id:
|
||||
body += f"<div style='text-align: center;'>You can find the file <a href='{domain_quivr}studio/{brain_id}'>here</a>.</div> <br />"
|
||||
body += """
|
||||
<div style="text-align: center;">
|
||||
<p>Please let us know if you have any questions or need further assistance.</p>
|
||||
|
||||
<p> The Quivr Team </p>
|
||||
</div>
|
||||
"""
|
||||
params = {
|
||||
"from": mail_from,
|
||||
"to": [mail_to],
|
||||
"subject": "Quivr Ingestion Processed",
|
||||
"reply_to": "no-reply@quivr.app",
|
||||
"html": body,
|
||||
"attachments": [{"filename": filename, "content": list(f.read())}],
|
||||
}
|
||||
logger.info(f"Sending email to {mail_to} with file {filename}")
|
||||
send_email(params)
|
||||
|
||||
async def uploadfile_to_file(self, uploadFile: UploadFile):
|
||||
# Transform the UploadFile object to a file object with same name and content
|
||||
tmp_file = NamedTemporaryFile(delete=False)
|
||||
tmp_file.write(uploadFile.file.read())
|
||||
tmp_file.flush() # Make sure all data is written to disk
|
||||
return tmp_file
|
||||
|
||||
async def create_and_upload_processed_file(
|
||||
self, processed_content: str, original_filename: str, file_description: str
|
||||
) -> dict:
|
||||
"""Handles creation and uploading of the processed file."""
|
||||
# remove any special characters from the filename that aren't http safe
|
||||
|
||||
new_filename = (
|
||||
original_filename.split(".")[0]
|
||||
+ "_"
|
||||
+ file_description.lower().replace(" ", "_")
|
||||
+ "_"
|
||||
+ str(random.randint(1000, 9999))
|
||||
+ ".pdf"
|
||||
)
|
||||
new_filename = unidecode(new_filename)
|
||||
new_filename = re.sub(
|
||||
"[^{}0-9a-zA-Z]".format(re.escape(string.punctuation)), "", new_filename
|
||||
)
|
||||
|
||||
self.generate_pdf(
|
||||
new_filename,
|
||||
f"{file_description} of {original_filename}",
|
||||
processed_content,
|
||||
)
|
||||
|
||||
content_io = BytesIO()
|
||||
with open(new_filename, "rb") as f:
|
||||
content_io.write(f.read())
|
||||
content_io.seek(0)
|
||||
|
||||
file_to_upload = UploadFile(
|
||||
filename=new_filename,
|
||||
file=content_io,
|
||||
headers={"content-type": "application/pdf"},
|
||||
)
|
||||
|
||||
if self.input.outputs.email.activated:
|
||||
await self.send_output_by_email(
|
||||
file_to_upload,
|
||||
new_filename,
|
||||
"Summary",
|
||||
f"{file_description} of {original_filename}",
|
||||
brain_id=(
|
||||
self.input.outputs.brain.value
|
||||
if (
|
||||
self.input.outputs.brain.activated
|
||||
and self.input.outputs.brain.value
|
||||
)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
# Reset to start of file before upload
|
||||
file_to_upload.file.seek(0)
|
||||
if self.input.outputs.brain.activated:
|
||||
await upload_file(
|
||||
uploadFile=file_to_upload,
|
||||
brain_id=self.input.outputs.brain.value,
|
||||
current_user=self.current_user,
|
||||
chat_id=None,
|
||||
)
|
||||
|
||||
os.remove(new_filename)
|
||||
|
||||
return {"message": f"{file_description} generated successfully"}
|
@ -1,196 +0,0 @@
|
||||
import tempfile
|
||||
from typing import List
|
||||
|
||||
from fastapi import UploadFile
|
||||
from langchain.chains import (
|
||||
MapReduceDocumentsChain,
|
||||
ReduceDocumentsChain,
|
||||
StuffDocumentsChain,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_community.document_loaders import UnstructuredPDFLoader
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_text_splitters import CharacterTextSplitter
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.assistant.dto.inputs import InputAssistant
|
||||
from quivr_api.modules.assistant.dto.outputs import (
|
||||
AssistantOutput,
|
||||
InputFile,
|
||||
Inputs,
|
||||
OutputBrain,
|
||||
OutputEmail,
|
||||
Outputs,
|
||||
)
|
||||
from quivr_api.modules.assistant.ito.ito import ITO
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SummaryAssistant(ITO):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input: InputAssistant,
|
||||
files: List[UploadFile] = None,
|
||||
current_user: UserIdentity = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
input=input,
|
||||
files=files,
|
||||
current_user=current_user,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def check_input(self):
|
||||
if not self.files:
|
||||
raise ValueError("No file was uploaded")
|
||||
if len(self.files) > 1:
|
||||
raise ValueError("Only one file can be uploaded")
|
||||
if not self.input.inputs.files:
|
||||
raise ValueError("No files key were given in the input")
|
||||
if len(self.input.inputs.files) > 1:
|
||||
raise ValueError("Only one file can be uploaded")
|
||||
if not self.input.inputs.files[0].key == "doc_to_summarize":
|
||||
raise ValueError("The key of the file should be doc_to_summarize")
|
||||
if not self.input.inputs.files[0].value:
|
||||
raise ValueError("No file was uploaded")
|
||||
# Check if name of file is same as the key
|
||||
if not self.input.inputs.files[0].value == self.files[0].filename:
|
||||
raise ValueError(
|
||||
"The key of the file should be the same as the name of the file"
|
||||
)
|
||||
if not (
|
||||
self.input.outputs.brain.activated or self.input.outputs.email.activated
|
||||
):
|
||||
raise ValueError("No output was selected")
|
||||
return True
|
||||
|
||||
async def process_assistant(self):
|
||||
|
||||
try:
|
||||
self.increase_usage_user()
|
||||
except Exception as e:
|
||||
logger.error(f"Error increasing usage: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
# Create a temporary file with the uploaded file as a temporary file and then pass it to the loader
|
||||
tmp_file = tempfile.NamedTemporaryFile(delete=False)
|
||||
|
||||
# Write the file to the temporary file
|
||||
tmp_file.write(self.files[0].file.read())
|
||||
|
||||
# Now pass the path of the temporary file to the loader
|
||||
|
||||
loader = UnstructuredPDFLoader(tmp_file.name)
|
||||
|
||||
tmp_file.close()
|
||||
|
||||
data = loader.load()
|
||||
|
||||
llm = ChatLiteLLM(model="gpt-4o", max_tokens=2000)
|
||||
|
||||
map_template = """The following is a document that has been divided into multiple sections:
|
||||
{docs}
|
||||
|
||||
Please carefully analyze each section and identify the following:
|
||||
|
||||
1. Main Themes: What are the overarching ideas or topics in this section?
|
||||
2. Key Points: What are the most important facts, arguments, or ideas presented in this section?
|
||||
3. Important Information: Are there any crucial details that stand out? This could include data, quotes, specific events, entity, or other relevant information.
|
||||
4. People: Who are the key individuals mentioned in this section? What roles do they play?
|
||||
5. Reasoning: What logic or arguments are used to support the key points?
|
||||
6. Chapters: If the document is divided into chapters, what is the main focus of each chapter?
|
||||
|
||||
Remember to consider the language and context of the document. This will help in understanding the nuances and subtleties of the text."""
|
||||
map_prompt = PromptTemplate.from_template(map_template)
|
||||
map_chain = LLMChain(llm=llm, prompt=map_prompt)
|
||||
|
||||
# Reduce
|
||||
reduce_template = """The following is a set of summaries for parts of the document:
|
||||
{docs}
|
||||
Take these and distill it into a final, consolidated summary of the document. Make sure to include the main themes, key points, and important information such as data, quotes,people and specific events.
|
||||
Use markdown such as bold, italics, underlined. For example, **bold**, *italics*, and _underlined_ to highlight key points.
|
||||
Please provide the final summary with sections using bold headers.
|
||||
Sections should always be Summary and Key Points, but feel free to add more sections as needed.
|
||||
Always use bold text for the sections headers.
|
||||
Keep the same language as the documents.
|
||||
Answer:"""
|
||||
reduce_prompt = PromptTemplate.from_template(reduce_template)
|
||||
|
||||
# Run chain
|
||||
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
|
||||
|
||||
# Takes a list of documents, combines them into a single string, and passes this to an LLMChain
|
||||
combine_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=reduce_chain, document_variable_name="docs"
|
||||
)
|
||||
|
||||
# Combines and iteratively reduces the mapped documents
|
||||
reduce_documents_chain = ReduceDocumentsChain(
|
||||
# This is final chain that is called.
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
# If documents exceed context for `StuffDocumentsChain`
|
||||
collapse_documents_chain=combine_documents_chain,
|
||||
# The maximum number of tokens to group documents into.
|
||||
token_max=4000,
|
||||
)
|
||||
|
||||
# Combining documents by mapping a chain over them, then combining results
|
||||
map_reduce_chain = MapReduceDocumentsChain(
|
||||
# Map chain
|
||||
llm_chain=map_chain,
|
||||
# Reduce chain
|
||||
reduce_documents_chain=reduce_documents_chain,
|
||||
# The variable name in the llm_chain to put the documents in
|
||||
document_variable_name="docs",
|
||||
# Return the results of the map steps in the output
|
||||
return_intermediate_steps=False,
|
||||
)
|
||||
|
||||
text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
|
||||
chunk_size=1000, chunk_overlap=100
|
||||
)
|
||||
split_docs = text_splitter.split_documents(data)
|
||||
|
||||
content = map_reduce_chain.run(split_docs)
|
||||
|
||||
return await self.create_and_upload_processed_file(
|
||||
content, self.files[0].filename, "Summary"
|
||||
)
|
||||
|
||||
|
||||
def summary_inputs():
|
||||
output = AssistantOutput(
|
||||
name="Summary",
|
||||
description="Summarize a set of documents",
|
||||
tags=["new"],
|
||||
input_description="One document to summarize",
|
||||
output_description="A summary of the document with key points and main themes",
|
||||
icon_url="https://quivr-cms.s3.eu-west-3.amazonaws.com/report_94bea8b918.png",
|
||||
inputs=Inputs(
|
||||
files=[
|
||||
InputFile(
|
||||
key="doc_to_summarize",
|
||||
allowed_extensions=["pdf"],
|
||||
required=True,
|
||||
description="The document to summarize",
|
||||
)
|
||||
]
|
||||
),
|
||||
outputs=Outputs(
|
||||
brain=OutputBrain(
|
||||
required=True,
|
||||
description="The brain to which upload the document",
|
||||
type="uuid",
|
||||
),
|
||||
email=OutputEmail(
|
||||
required=True,
|
||||
description="Send the document by email",
|
||||
type="str",
|
||||
),
|
||||
),
|
||||
)
|
||||
return output
|
Binary file not shown.
Before Width: | Height: | Size: 23 KiB |
@ -1 +0,0 @@
|
||||
from .assistant_interface import AssistantInterface
|
@ -1,16 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from quivr_api.modules.assistant.entity.assistant import AssistantEntity
|
||||
|
||||
|
||||
class AssistantInterface(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_all_assistants(self) -> List[AssistantEntity]:
|
||||
"""
|
||||
Get all the knowledge in a brain
|
||||
Args:
|
||||
brain_id (UUID): The id of the brain
|
||||
"""
|
||||
pass
|
@ -0,0 +1,32 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.assistant.dto.inputs import CreateTask
|
||||
from quivr_api.modules.assistant.entity.task_entity import Task
|
||||
|
||||
|
||||
class TasksInterface(ABC):
|
||||
@abstractmethod
|
||||
def create_task(self, task: CreateTask) -> Task:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_task(self, task_id: UUID, user_id: UUID) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_tasks_by_user_id(self, user_id: UUID) -> List[Task]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_task(self, task_id: int, task: dict) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
|
||||
pass
|
82
backend/api/quivr_api/modules/assistant/repository/tasks.py
Normal file
82
backend/api/quivr_api/modules/assistant/repository/tasks.py
Normal file
@ -0,0 +1,82 @@
|
||||
from typing import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from quivr_api.modules.assistant.dto.inputs import CreateTask
|
||||
from quivr_api.modules.assistant.entity.task_entity import Task
|
||||
from quivr_api.modules.dependencies import BaseRepository
|
||||
from quivr_api.modules.upload.service.generate_file_signed_url import (
|
||||
generate_file_signed_url,
|
||||
)
|
||||
|
||||
|
||||
class TasksRepository(BaseRepository):
|
||||
def __init__(self, session: AsyncSession):
|
||||
super().__init__(session)
|
||||
|
||||
async def create_task(self, task: CreateTask, user_id: UUID) -> Task:
|
||||
try:
|
||||
task_to_create = Task(
|
||||
assistant_id=task.assistant_id,
|
||||
pretty_id=task.pretty_id,
|
||||
user_id=user_id,
|
||||
settings=task.settings,
|
||||
)
|
||||
self.session.add(task_to_create)
|
||||
await self.session.commit()
|
||||
except exc.IntegrityError:
|
||||
await self.session.rollback()
|
||||
raise Exception()
|
||||
|
||||
await self.session.refresh(task_to_create)
|
||||
return task_to_create
|
||||
|
||||
async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
|
||||
query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
|
||||
response = await self.session.exec(query)
|
||||
return response.one()
|
||||
|
||||
async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]:
|
||||
query = select(Task).where(Task.user_id == user_id)
|
||||
response = await self.session.exec(query)
|
||||
return response.all()
|
||||
|
||||
async def delete_task(self, task_id: int, user_id: UUID) -> None:
|
||||
query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
|
||||
response = await self.session.exec(query)
|
||||
task = response.one()
|
||||
if task:
|
||||
await self.session.delete(task)
|
||||
await self.session.commit()
|
||||
else:
|
||||
raise Exception()
|
||||
|
||||
async def update_task(self, task_id: int, task_updates: dict) -> None:
|
||||
query = select(Task).where(Task.id == task_id)
|
||||
response = await self.session.exec(query)
|
||||
task = response.one()
|
||||
if task:
|
||||
for key, value in task_updates.items():
|
||||
setattr(task, key, value)
|
||||
await self.session.commit()
|
||||
else:
|
||||
raise Exception("Task not found")
|
||||
|
||||
async def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
|
||||
query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
|
||||
response = await self.session.exec(query)
|
||||
task = response.one()
|
||||
|
||||
path = f"{task.assistant_id}/{task.pretty_id}/output.pdf"
|
||||
|
||||
try:
|
||||
signed_url = generate_file_signed_url(path)
|
||||
if signed_url and "signedURL" in signed_url:
|
||||
return signed_url["signedURL"]
|
||||
else:
|
||||
raise Exception()
|
||||
except Exception:
|
||||
return "error"
|
@ -1,32 +0,0 @@
|
||||
from quivr_api.modules.assistant.entity.assistant import AssistantEntity
|
||||
from quivr_api.modules.assistant.repository.assistant_interface import (
|
||||
AssistantInterface,
|
||||
)
|
||||
from quivr_api.modules.dependencies import get_supabase_client
|
||||
|
||||
|
||||
class Assistant(AssistantInterface):
|
||||
def __init__(self):
|
||||
supabase_client = get_supabase_client()
|
||||
self.db = supabase_client
|
||||
|
||||
def get_all_assistants(self):
|
||||
response = self.db.from_("assistants").select("*").execute()
|
||||
|
||||
if response.data:
|
||||
return response.data
|
||||
|
||||
return []
|
||||
|
||||
def get_assistant_by_id(self, ingestion_id) -> AssistantEntity:
|
||||
response = (
|
||||
self.db.from_("assistants")
|
||||
.select("*")
|
||||
.filter("id", "eq", ingestion_id)
|
||||
.execute()
|
||||
)
|
||||
|
||||
if response.data:
|
||||
return AssistantEntity(**response.data[0])
|
||||
|
||||
return None
|
@ -0,0 +1,32 @@
|
||||
from typing import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.assistant.dto.inputs import CreateTask
|
||||
from quivr_api.modules.assistant.entity.task_entity import Task
|
||||
from quivr_api.modules.assistant.repository.tasks import TasksRepository
|
||||
from quivr_api.modules.dependencies import BaseService
|
||||
|
||||
|
||||
class TasksService(BaseService[TasksRepository]):
|
||||
repository_cls = TasksRepository
|
||||
|
||||
def __init__(self, repository: TasksRepository):
|
||||
self.repository = repository
|
||||
|
||||
async def create_task(self, task: CreateTask, user_id: UUID) -> Task:
|
||||
return await self.repository.create_task(task, user_id)
|
||||
|
||||
async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
|
||||
return await self.repository.get_task_by_id(task_id, user_id)
|
||||
|
||||
async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]:
|
||||
return await self.repository.get_tasks_by_user_id(user_id)
|
||||
|
||||
async def delete_task(self, task_id: int, user_id: UUID) -> None:
|
||||
return await self.repository.delete_task(task_id, user_id)
|
||||
|
||||
async def update_task(self, task_id: int, task: dict) -> None:
|
||||
return await self.repository.update_task(task_id, task)
|
||||
|
||||
async def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
|
||||
return await self.repository.get_download_link_task(task_id, user_id)
|
@ -1 +1,5 @@
|
||||
from .brain_routes import brain_router
|
||||
|
||||
__all__ = [
|
||||
"brain_router",
|
||||
]
|
||||
|
@ -2,6 +2,7 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainType
|
||||
from quivr_api.modules.brain.entity.integration_brain import IntegrationType
|
||||
|
@ -11,6 +11,7 @@ from langchain_core.prompts.chat import (
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
|
@ -4,6 +4,7 @@ from uuid import UUID
|
||||
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
|
||||
|
@ -10,6 +10,7 @@ from langchain_core.tools import BaseTool
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, StateGraph
|
||||
from langgraph.prebuilt import ToolExecutor, ToolInvocation
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
|
@ -4,6 +4,7 @@ from uuid import UUID
|
||||
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
|
@ -7,6 +7,7 @@ from langchain_community.utilities import SQLDatabase
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
|
||||
from quivr_api.modules.brain.integrations.SQL.SQL_connector import SQLConnector
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.brain.repository.integration_brains import IntegrationBrain
|
||||
@ -85,7 +86,6 @@ class SQLBrain(KnowledgeBrainQA, IntegrationBrain):
|
||||
async def generate_stream(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> AsyncIterable:
|
||||
|
||||
conversational_qa_chain = self.get_chain()
|
||||
transformed_history, streamed_chat_history = (
|
||||
self.initialize_streamed_chat_history(chat_id, question)
|
||||
|
@ -12,13 +12,14 @@ from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
|
||||
from langchain_core.pydantic_v1 import Field as FieldV1
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, StateGraph
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
# Post-processing
|
||||
@ -210,13 +211,11 @@ class SelfBrain(KnowledgeBrainQA):
|
||||
return question_rewriter
|
||||
|
||||
def get_chain(self):
|
||||
|
||||
graph = self.create_graph()
|
||||
|
||||
return graph
|
||||
|
||||
def create_graph(self):
|
||||
|
||||
workflow = StateGraph(GraphState)
|
||||
|
||||
# Define the nodes
|
||||
|
@ -2,5 +2,7 @@
|
||||
from .brains_interface import BrainsInterface
|
||||
from .brains_users_interface import BrainsUsersInterface
|
||||
from .brains_vectors_interface import BrainsVectorsInterface
|
||||
from .integration_brains_interface import (IntegrationBrainInterface,
|
||||
IntegrationDescriptionInterface)
|
||||
from .integration_brains_interface import (
|
||||
IntegrationBrainInterface,
|
||||
IntegrationDescriptionInterface,
|
||||
)
|
||||
|
@ -38,7 +38,6 @@ class IntegrationBrainInterface(ABC):
|
||||
|
||||
|
||||
class IntegrationDescriptionInterface(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_integration_description(
|
||||
self, integration_id: UUID
|
||||
|
@ -2,6 +2,7 @@ from typing import List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
|
||||
from quivr_api.middlewares.auth.auth_bearer import get_current_user
|
||||
from quivr_api.modules.brain.entity.brain_entity import RoleEnum
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
@ -13,7 +14,7 @@ brain_service = BrainService()
|
||||
|
||||
|
||||
def has_brain_authorization(
|
||||
required_roles: Optional[Union[RoleEnum, List[RoleEnum]]] = RoleEnum.Owner
|
||||
required_roles: Optional[Union[RoleEnum, List[RoleEnum]]] = RoleEnum.Owner,
|
||||
):
|
||||
"""
|
||||
Decorator to check if the user has the required role(s) for the brain
|
||||
|
@ -2,6 +2,7 @@ from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.entity.brain_entity import (
|
||||
BrainEntity,
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
|
||||
|
||||
|
@ -15,5 +15,7 @@ def get_prompt_to_use_id(
|
||||
return (
|
||||
prompt_id
|
||||
if prompt_id
|
||||
else brain_service.get_brain_prompt_id(brain_id) if brain_id else None
|
||||
else brain_service.get_brain_prompt_id(brain_id)
|
||||
if brain_id
|
||||
else None
|
||||
)
|
||||
|
@ -1,2 +0,0 @@
|
||||
from fastapi import HTTPException
|
||||
from quivr_api.modules.brain.dto.inputs import CreateBrainProperties
|
@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from quivr_api.modules.notification.entity.notification import Notification
|
||||
|
||||
|
@ -2,12 +2,12 @@ from datetime import datetime
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import JSON, TIMESTAMP, Column, Field, Relationship, SQLModel, text
|
||||
from sqlmodel import UUID as PGUUID
|
||||
|
||||
from quivr_api.modules.brain.entity.brain_entity import Brain
|
||||
from quivr_api.modules.user.entity.user_identity import User
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import JSON, TIMESTAMP
|
||||
from sqlmodel import UUID as PGUUID
|
||||
from sqlmodel import Column, Field, Relationship, SQLModel, text
|
||||
|
||||
|
||||
class Chat(SQLModel, table=True):
|
||||
|
@ -1 +1 @@
|
||||
from .knowledge_routes import knowledge_router
|
||||
from .knowledge_routes import knowledge_router
|
||||
|
@ -85,3 +85,5 @@ class SupabaseS3Storage(StorageInterface):
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import Annotated, List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
|
@ -1 +1 @@
|
||||
from .inputs import NotificationUpdatableProperties
|
||||
from .inputs import NotificationUpdatableProperties
|
||||
|
@ -1,6 +1,7 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from quivr_api.middlewares.auth import AuthBearer
|
||||
from quivr_api.modules.prompt.entity.prompt import (
|
||||
CreatePromptProperties,
|
||||
|
@ -1 +1,7 @@
|
||||
from .prompt import Prompt, PromptStatusEnum, CreatePromptProperties, PromptUpdatableProperties, DeletePromptResponse
|
||||
from .prompt import (
|
||||
CreatePromptProperties,
|
||||
DeletePromptResponse,
|
||||
Prompt,
|
||||
PromptStatusEnum,
|
||||
PromptUpdatableProperties,
|
||||
)
|
||||
|
@ -4,6 +4,7 @@ import requests
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from msal import ConfidentialClientApplication
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
import requests
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
|
||||
|
@ -50,4 +50,4 @@ successfullConnectionPage = """
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
"""
|
||||
|
@ -3,6 +3,7 @@ from typing import Any, List, Literal, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from quivr_api.modules.sync.entity.sync_models import NotionSyncFile
|
||||
|
||||
|
||||
|
@ -2,20 +2,20 @@ from datetime import datetime, timedelta
|
||||
from typing import List, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.dependencies import (BaseRepository, get_supabase_client)
|
||||
from quivr_api.modules.notification.service.notification_service import \
|
||||
NotificationService
|
||||
from quivr_api.modules.sync.dto.inputs import (SyncsActiveInput,
|
||||
SyncsActiveUpdateInput)
|
||||
from quivr_api.modules.sync.entity.sync_models import (NotionSyncFile,
|
||||
SyncsActive)
|
||||
from quivr_api.modules.sync.repository.sync_interfaces import SyncInterface
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.dependencies import BaseRepository, get_supabase_client
|
||||
from quivr_api.modules.notification.service.notification_service import (
|
||||
NotificationService,
|
||||
)
|
||||
from quivr_api.modules.sync.dto.inputs import SyncsActiveInput, SyncsActiveUpdateInput
|
||||
from quivr_api.modules.sync.entity.sync_models import NotionSyncFile, SyncsActive
|
||||
from quivr_api.modules.sync.repository.sync_interfaces import SyncInterface
|
||||
|
||||
notification_service = NotificationService()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .email_sender import EmailSenderTool
|
||||
from .image_generator import ImageGeneratorTool
|
||||
from .web_search import WebSearchTool
|
||||
from .url_reader import URLReaderTool
|
||||
from .email_sender import EmailSenderTool
|
||||
from .web_search import WebSearchTool
|
||||
|
@ -11,6 +11,7 @@ from langchain.pydantic_v1 import Field as FieldV1
|
||||
from langchain_community.document_loaders import PlaywrightURLLoader
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -29,7 +30,6 @@ class URLReaderTool(BaseTool):
|
||||
def _run(
|
||||
self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
) -> Dict:
|
||||
|
||||
loader = PlaywrightURLLoader(urls=[url], remove_selectors=["header", "footer"])
|
||||
data = loader.load()
|
||||
|
||||
|
@ -10,6 +10,7 @@ from langchain.pydantic_v1 import BaseModel as BaseModelV1
|
||||
from langchain.pydantic_v1 import Field as FieldV1
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
@ -1,8 +1,8 @@
|
||||
import os
|
||||
from multiprocessing import get_logger
|
||||
|
||||
from quivr_api.modules.dependencies import get_supabase_client
|
||||
from supabase.client import Client
|
||||
import os
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@ -10,6 +10,7 @@ SIGNED_URL_EXPIRATION_PERIOD_IN_SECONDS = 3600
|
||||
EXTERNAL_SUPABASE_URL = os.getenv("EXTERNAL_SUPABASE_URL", None)
|
||||
SUPABASE_URL = os.getenv("SUPABASE_URL", None)
|
||||
|
||||
|
||||
def generate_file_signed_url(path):
|
||||
supabase_client: Client = get_supabase_client()
|
||||
|
||||
@ -25,7 +26,9 @@ def generate_file_signed_url(path):
|
||||
logger.info("RESPONSE SIGNED URL", response)
|
||||
# Replace in the response the supabase url by the external supabase url in the object signedURL
|
||||
if EXTERNAL_SUPABASE_URL and SUPABASE_URL:
|
||||
response["signedURL"] = response["signedURL"].replace(SUPABASE_URL, EXTERNAL_SUPABASE_URL)
|
||||
response["signedURL"] = response["signedURL"].replace(
|
||||
SUPABASE_URL, EXTERNAL_SUPABASE_URL
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
@ -1,8 +1,7 @@
|
||||
from multiprocessing import get_logger
|
||||
|
||||
from supabase.client import Client
|
||||
|
||||
from quivr_api.modules.dependencies import get_supabase_client
|
||||
from supabase.client import Client
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
@ -10,4 +10,3 @@ class UserUpdatableProperties(BaseModel):
|
||||
onboarded: Optional[bool] = None
|
||||
company_size: Optional[str] = None
|
||||
usage_purpose: Optional[str] = None
|
||||
|
||||
|
@ -1 +1 @@
|
||||
from .user_service import UserService
|
||||
from .user_service import UserService
|
||||
|
@ -3,12 +3,11 @@ from uuid import UUID
|
||||
|
||||
from pgvector.sqlalchemy import Vector as PGVector
|
||||
from pydantic import BaseModel
|
||||
from quivr_api.models.settings import settings
|
||||
from sqlalchemy import Column
|
||||
from sqlmodel import JSON, Column, Field, SQLModel, text
|
||||
from sqlmodel import UUID as PGUUID
|
||||
|
||||
from quivr_api.models.settings import settings
|
||||
|
||||
|
||||
class Vector(SQLModel, table=True):
|
||||
__tablename__ = "vectors" # type: ignore
|
||||
|
@ -3,6 +3,7 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
|
||||
from quivr_api.models.brains_subscription_invitations import BrainSubscription
|
||||
@ -253,7 +254,7 @@ async def accept_invitation(
|
||||
is_default_brain=False,
|
||||
)
|
||||
shared_brain = brain_service.get_brain_by_id(brain_id)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding user to brain: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Error adding user to brain: {e}")
|
||||
|
@ -1,6 +1,7 @@
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
@ -1,38 +1,46 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
from quivr_api.modules.dependencies import get_embedding_client
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain_community.embeddings.ollama import OllamaEmbeddings
|
||||
from langchain_openai import AzureOpenAIEmbeddings
|
||||
from quivr_api.modules.dependencies import get_embedding_client
|
||||
|
||||
|
||||
def test_ollama_embedding():
|
||||
with patch("quivr_api.modules.dependencies.settings") as mock_settings:
|
||||
mock_settings.ollama_api_base_url = "http://ollama.example.com"
|
||||
mock_settings.azure_openai_embeddings_url = None
|
||||
|
||||
|
||||
embedding_client = get_embedding_client()
|
||||
|
||||
|
||||
assert isinstance(embedding_client, OllamaEmbeddings)
|
||||
assert embedding_client.base_url == "http://ollama.example.com"
|
||||
|
||||
|
||||
def test_azure_embedding():
|
||||
with patch("quivr_api.modules.dependencies.settings") as mock_settings:
|
||||
mock_settings.ollama_api_base_url = None
|
||||
mock_settings.azure_openai_embeddings_url = "https://quivr-test.openai.azure.com/openai/deployments/embedding/embeddings?api-version=2023-05-15"
|
||||
|
||||
|
||||
embedding_client = get_embedding_client()
|
||||
|
||||
|
||||
assert isinstance(embedding_client, AzureOpenAIEmbeddings)
|
||||
assert embedding_client.azure_endpoint == "https://quivr-test.openai.azure.com"
|
||||
|
||||
|
||||
def test_openai_embedding():
|
||||
with patch("quivr_api.modules.dependencies.settings") as mock_settings, \
|
||||
patch("quivr_api.modules.dependencies.OpenAIEmbeddings") as mock_openai_embeddings:
|
||||
with (
|
||||
patch("quivr_api.modules.dependencies.settings") as mock_settings,
|
||||
patch(
|
||||
"quivr_api.modules.dependencies.OpenAIEmbeddings"
|
||||
) as mock_openai_embeddings,
|
||||
):
|
||||
mock_settings.ollama_api_base_url = None
|
||||
mock_settings.azure_openai_embeddings_url = None
|
||||
|
||||
|
||||
# Create a mock instance for OpenAIEmbeddings
|
||||
mock_openai_instance = MagicMock()
|
||||
mock_openai_embeddings.return_value = mock_openai_instance
|
||||
|
||||
|
||||
embedding_client = get_embedding_client()
|
||||
|
||||
|
||||
assert embedding_client == mock_openai_instance
|
||||
|
@ -1,12 +1,11 @@
|
||||
from langchain_core.embeddings import DeterministicFakeEmbedding
|
||||
from langchain_core.language_models import FakeListChatModel
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt
|
||||
|
||||
from quivr_core import Brain
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.llm.llm_endpoint import LLMEndpoint
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt
|
||||
|
||||
if __name__ == "__main__":
|
||||
brain = Brain.from_files(
|
||||
|
@ -3,6 +3,7 @@ from typing import Any, Generator, Tuple
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from quivr_core.models import ChatMessage
|
||||
|
||||
|
||||
@ -54,7 +55,7 @@ class ChatHistory:
|
||||
"""
|
||||
# Reverse the chat_history, newest first
|
||||
it = iter(self.get_chat_history(newest_first=True))
|
||||
for ai_message, human_message in zip(it, it):
|
||||
for ai_message, human_message in zip(it, it, strict=False):
|
||||
assert isinstance(
|
||||
human_message.msg, HumanMessage
|
||||
), f"msg {human_message} is not HumanMessage"
|
||||
|
@ -55,7 +55,7 @@ class ChatLLM:
|
||||
filtered_chat_history.append(ai_message)
|
||||
total_tokens += message_tokens
|
||||
total_pairs += 1
|
||||
|
||||
|
||||
return filtered_chat_history
|
||||
|
||||
def build_chain(self):
|
||||
|
@ -14,19 +14,20 @@ logger = logging.getLogger("quivr_core")
|
||||
|
||||
|
||||
class MegaparseProcessor(ProcessorBase):
|
||||
'''
|
||||
"""
|
||||
Megaparse processor for PDF files.
|
||||
|
||||
|
||||
It can be used to parse PDF files and split them into chunks.
|
||||
|
||||
|
||||
It comes from the megaparse library.
|
||||
|
||||
|
||||
## Installation
|
||||
```bash
|
||||
pip install megaparse
|
||||
```
|
||||
|
||||
'''
|
||||
|
||||
"""
|
||||
|
||||
supported_extensions = [FileExtension.pdf]
|
||||
|
||||
def __init__(
|
||||
|
@ -37,6 +37,7 @@ def model_supports_function_calling(model_name: str):
|
||||
]
|
||||
return model_name in models_supporting_function_calls
|
||||
|
||||
|
||||
def format_history_to_openai_mesages(
|
||||
tuple_history: List[Tuple[str, str]], system_message: str, question: str
|
||||
) -> List[BaseMessage]:
|
||||
@ -125,7 +126,11 @@ def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGRe
|
||||
)
|
||||
|
||||
if model_supports_function_calling(model_name):
|
||||
if 'tool_calls' in raw_response["answer"] and raw_response["answer"].tool_calls and "citations" in raw_response["answer"].tool_calls[-1]["args"]:
|
||||
if (
|
||||
"tool_calls" in raw_response["answer"]
|
||||
and raw_response["answer"].tool_calls
|
||||
and "citations" in raw_response["answer"].tool_calls[-1]["args"]
|
||||
):
|
||||
citations = raw_response["answer"].tool_calls[-1]["args"]["citations"]
|
||||
metadata.citations = citations
|
||||
followup_questions = raw_response["answer"].tool_calls[-1]["args"][
|
||||
@ -147,7 +152,7 @@ def combine_documents(
|
||||
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
|
||||
):
|
||||
# for each docs, add an index in the metadata to be able to cite the sources
|
||||
for doc, index in zip(docs, range(len(docs))):
|
||||
for doc, index in zip(docs, range(len(docs)), strict=False):
|
||||
doc.metadata["index"] = index
|
||||
doc_strings = [format_document(doc, document_prompt) for doc in docs]
|
||||
return document_separator.join(doc_strings)
|
||||
|
@ -5,7 +5,6 @@ from uuid import uuid4
|
||||
from langchain_core.embeddings import DeterministicFakeEmbedding
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.vectorstores import InMemoryVectorStore
|
||||
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import LLMEndpointConfig, RAGConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
|
@ -2,7 +2,6 @@ from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from quivr_core.files.file import FileExtension, QuivrFile
|
||||
from quivr_core.processor.implementations.default import MarkdownProcessor
|
||||
|
||||
|
@ -2,7 +2,6 @@ from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from quivr_core.files.file import FileExtension, QuivrFile
|
||||
from quivr_core.processor.implementations.default import DOCXProcessor
|
||||
|
||||
|
@ -2,7 +2,6 @@ from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from quivr_core.files.file import FileExtension, QuivrFile
|
||||
from quivr_core.processor.implementations.default import EpubProcessor
|
||||
|
||||
|
@ -2,7 +2,6 @@ from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from quivr_core.files.file import FileExtension, QuivrFile
|
||||
from quivr_core.processor.implementations.default import ODTProcessor
|
||||
|
||||
|
@ -2,7 +2,6 @@ from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from quivr_core.files.file import FileExtension, QuivrFile
|
||||
from quivr_core.processor.implementations.default import UnstructuredPDFProcessor
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
|
||||
from quivr_core.files.file import FileExtension
|
||||
from quivr_core.processor.processor_base import ProcessorBase
|
||||
|
||||
@ -7,7 +6,6 @@ from quivr_core.processor.processor_base import ProcessorBase
|
||||
@pytest.mark.base
|
||||
def test___build_processor():
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
|
||||
from quivr_core.processor.implementations.default import _build_processor
|
||||
|
||||
cls = _build_processor("TestCLS", BaseLoader, [FileExtension.txt])
|
||||
|
@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from quivr_core.files.file import FileExtension
|
||||
from quivr_core.processor.implementations.simple_txt_processor import (
|
||||
SimpleTxtProcessor,
|
||||
|
@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
|
||||
from quivr_core.processor.implementations.tika_processor import TikaProcessor
|
||||
|
||||
# TODO: TIKA server should be set
|
||||
|
@ -3,7 +3,6 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from quivr_core.chat import ChatHistory
|
||||
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
|
||||
from quivr_core import ChatLLM
|
||||
|
||||
|
||||
|
@ -3,7 +3,6 @@ import os
|
||||
import pytest
|
||||
from langchain_core.language_models import FakeListChatModel
|
||||
from pydantic.v1.error_wrappers import ValidationError
|
||||
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
|
||||
|
@ -3,7 +3,6 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
|
||||
from quivr_core.utils import (
|
||||
get_prev_message_str,
|
||||
model_supports_function_calling,
|
||||
|
@ -173,6 +173,7 @@ debugpy==1.8.5
|
||||
decorator==5.1.1
|
||||
# via ipython
|
||||
defusedxml==0.7.1
|
||||
# via fpdf2
|
||||
# via langchain-anthropic
|
||||
# via nbconvert
|
||||
deprecated==1.2.14
|
||||
@ -238,10 +239,11 @@ flatbuffers==24.3.25
|
||||
flower==2.0.1
|
||||
# via quivr-worker
|
||||
fonttools==4.53.1
|
||||
# via fpdf2
|
||||
# via matplotlib
|
||||
# via pdf2docx
|
||||
fpdf==1.7.2
|
||||
# via quivr-api
|
||||
fpdf2==2.7.9
|
||||
# via quivr-worker
|
||||
frozenlist==1.4.1
|
||||
# via aiohttp
|
||||
# via aiosignal
|
||||
@ -747,6 +749,7 @@ pgvector==0.3.2
|
||||
pikepdf==9.1.1
|
||||
# via unstructured
|
||||
pillow==10.2.0
|
||||
# via fpdf2
|
||||
# via layoutparser
|
||||
# via llama-index-core
|
||||
# via matplotlib
|
||||
|
@ -150,6 +150,7 @@ debugpy==1.8.5
|
||||
decorator==5.1.1
|
||||
# via ipython
|
||||
defusedxml==0.7.1
|
||||
# via fpdf2
|
||||
# via langchain-anthropic
|
||||
# via nbconvert
|
||||
deprecated==1.2.14
|
||||
@ -200,10 +201,11 @@ flatbuffers==24.3.25
|
||||
flower==2.0.1
|
||||
# via quivr-worker
|
||||
fonttools==4.53.1
|
||||
# via fpdf2
|
||||
# via matplotlib
|
||||
# via pdf2docx
|
||||
fpdf==1.7.2
|
||||
# via quivr-api
|
||||
fpdf2==2.7.9
|
||||
# via quivr-worker
|
||||
frozenlist==1.4.1
|
||||
# via aiohttp
|
||||
# via aiosignal
|
||||
@ -650,6 +652,7 @@ pgvector==0.3.2
|
||||
pikepdf==9.1.1
|
||||
# via unstructured
|
||||
pillow==10.2.0
|
||||
# via fpdf2
|
||||
# via layoutparser
|
||||
# via llama-index-core
|
||||
# via matplotlib
|
||||
|
79
backend/supabase/migrations/20240911145305_gtasks.sql
Normal file
79
backend/supabase/migrations/20240911145305_gtasks.sql
Normal file
@ -0,0 +1,79 @@
|
||||
create table "public"."tasks" (
|
||||
"id" bigint generated by default as identity not null,
|
||||
"pretty_id" text,
|
||||
"user_id" uuid not null default auth.uid(),
|
||||
"status" text,
|
||||
"creation_time" timestamp with time zone default (now() AT TIME ZONE 'utc'::text),
|
||||
"answer_raw" jsonb,
|
||||
"answer_pretty" text
|
||||
);
|
||||
|
||||
|
||||
alter table "public"."tasks" enable row level security;
|
||||
|
||||
CREATE UNIQUE INDEX tasks_pkey ON public.tasks USING btree (id);
|
||||
|
||||
alter table "public"."tasks" add constraint "tasks_pkey" PRIMARY KEY using index "tasks_pkey";
|
||||
|
||||
alter table "public"."tasks" add constraint "tasks_user_id_fkey" FOREIGN KEY (user_id) REFERENCES auth.users(id) ON UPDATE CASCADE ON DELETE CASCADE not valid;
|
||||
|
||||
alter table "public"."tasks" validate constraint "tasks_user_id_fkey";
|
||||
|
||||
grant delete on table "public"."tasks" to "anon";
|
||||
|
||||
grant insert on table "public"."tasks" to "anon";
|
||||
|
||||
grant references on table "public"."tasks" to "anon";
|
||||
|
||||
grant select on table "public"."tasks" to "anon";
|
||||
|
||||
grant trigger on table "public"."tasks" to "anon";
|
||||
|
||||
grant truncate on table "public"."tasks" to "anon";
|
||||
|
||||
grant update on table "public"."tasks" to "anon";
|
||||
|
||||
grant delete on table "public"."tasks" to "authenticated";
|
||||
|
||||
grant insert on table "public"."tasks" to "authenticated";
|
||||
|
||||
grant references on table "public"."tasks" to "authenticated";
|
||||
|
||||
grant select on table "public"."tasks" to "authenticated";
|
||||
|
||||
grant trigger on table "public"."tasks" to "authenticated";
|
||||
|
||||
grant truncate on table "public"."tasks" to "authenticated";
|
||||
|
||||
grant update on table "public"."tasks" to "authenticated";
|
||||
|
||||
grant delete on table "public"."tasks" to "service_role";
|
||||
|
||||
grant insert on table "public"."tasks" to "service_role";
|
||||
|
||||
grant references on table "public"."tasks" to "service_role";
|
||||
|
||||
grant select on table "public"."tasks" to "service_role";
|
||||
|
||||
grant trigger on table "public"."tasks" to "service_role";
|
||||
|
||||
grant truncate on table "public"."tasks" to "service_role";
|
||||
|
||||
grant update on table "public"."tasks" to "service_role";
|
||||
|
||||
create policy "allow_user_all_tasks"
|
||||
on "public"."tasks"
|
||||
as permissive
|
||||
for all
|
||||
to public
|
||||
using ((user_id = ( SELECT auth.uid() AS uid)));
|
||||
|
||||
|
||||
create policy "tasks"
|
||||
on "public"."tasks"
|
||||
as permissive
|
||||
for all
|
||||
to service_role;
|
||||
|
||||
|
||||
|
11
backend/supabase/migrations/20240918094405_assistants.sql
Normal file
11
backend/supabase/migrations/20240918094405_assistants.sql
Normal file
@ -0,0 +1,11 @@
|
||||
alter table "public"."tasks" drop column "answer_pretty";
|
||||
|
||||
alter table "public"."tasks" drop column "answer_raw";
|
||||
|
||||
alter table "public"."tasks" add column "answer" text;
|
||||
|
||||
alter table "public"."tasks" add column "assistant_id" bigint not null;
|
||||
|
||||
alter table "public"."tasks" add column "settings" jsonb;
|
||||
|
||||
|
@ -13,10 +13,11 @@ dependencies = [
|
||||
"playwright>=1.0.0",
|
||||
"openai>=1.0.0",
|
||||
"flower>=2.0.1",
|
||||
"torch==2.4.0; platform_machine != 'x86_64'" ,
|
||||
"torch==2.4.0; platform_machine != 'x86_64'",
|
||||
"torch==2.4.0+cpu; platform_machine == 'x86_64'",
|
||||
"torchvision==0.19.0; platform_machine != 'x86_64'",
|
||||
"torchvision==0.19.0+cpu; platform_machine == 'x86_64'",
|
||||
"fpdf2>=2.7.9",
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.11"
|
||||
|
40
backend/worker/quivr_worker/assistants/assistants.py
Normal file
40
backend/worker/quivr_worker/assistants/assistants.py
Normal file
@ -0,0 +1,40 @@
|
||||
import os
|
||||
|
||||
from quivr_api.modules.assistant.services.tasks_service import TasksService
|
||||
from quivr_api.modules.upload.service.upload_file import (
|
||||
upload_file_storage,
|
||||
)
|
||||
|
||||
from quivr_worker.utils.pdf_generator.pdf_generator import PDFGenerator, PDFModel
|
||||
|
||||
|
||||
async def process_assistant(
|
||||
assistant_id: str,
|
||||
notification_uuid: str,
|
||||
task_id: int,
|
||||
tasks_service: TasksService,
|
||||
user_id: str,
|
||||
):
|
||||
task = await tasks_service.get_task_by_id(task_id, user_id) # type: ignore
|
||||
|
||||
await tasks_service.update_task(task_id, {"status": "in_progress"})
|
||||
|
||||
print(task)
|
||||
|
||||
task_result = {"status": "completed", "answer": "#### Assistant answer"}
|
||||
|
||||
output_dir = f"{assistant_id}/{notification_uuid}"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_path = f"{output_dir}/output.pdf"
|
||||
|
||||
generated_pdf = PDFGenerator(PDFModel(title="Test", content="Test"))
|
||||
generated_pdf.print_pdf()
|
||||
generated_pdf.output(output_path)
|
||||
|
||||
with open(output_path, "rb") as file:
|
||||
await upload_file_storage(file, output_path)
|
||||
|
||||
# Now delete the file
|
||||
os.remove(output_path)
|
||||
|
||||
await tasks_service.update_task(task_id, task_result)
|
@ -8,6 +8,8 @@ from dotenv import load_dotenv
|
||||
from quivr_api.celery_config import celery
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.models.settings import settings
|
||||
from quivr_api.modules.assistant.repository.tasks import TasksRepository
|
||||
from quivr_api.modules.assistant.services.tasks_service import TasksService
|
||||
from quivr_api.modules.brain.integrations.Notion.Notion_connector import NotionConnector
|
||||
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
@ -29,6 +31,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlmodel import Session, text
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from quivr_worker.assistants.assistants import process_assistant
|
||||
from quivr_worker.check_premium import check_is_premium
|
||||
from quivr_worker.process.process_s3_file import process_uploaded_file
|
||||
from quivr_worker.process.process_url import process_url_func
|
||||
@ -39,7 +42,7 @@ from quivr_worker.syncs.process_active_syncs import (
|
||||
process_sync,
|
||||
)
|
||||
from quivr_worker.syncs.store_notion import fetch_and_store_notion_files_async
|
||||
from quivr_worker.utils import _patch_json
|
||||
from quivr_worker.utils.utils import _patch_json
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@ -91,6 +94,63 @@ def init_worker(**kwargs):
|
||||
)
|
||||
|
||||
|
||||
@celery.task(
|
||||
retries=3,
|
||||
default_retry_delay=1,
|
||||
name="process_assistant_task",
|
||||
autoretry_for=(Exception,),
|
||||
)
|
||||
def process_assistant_task(
|
||||
assistant_id: str,
|
||||
notification_uuid: str,
|
||||
task_id: int,
|
||||
user_id: str,
|
||||
):
|
||||
logger.info(
|
||||
f"process_assistant_task started for assistant_id={assistant_id}, notification_uuid={notification_uuid}, task_id={task_id}"
|
||||
)
|
||||
print("process_assistant_task")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
aprocess_assistant_task(
|
||||
assistant_id,
|
||||
notification_uuid,
|
||||
task_id,
|
||||
user_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def aprocess_assistant_task(
|
||||
assistant_id: str,
|
||||
notification_uuid: str,
|
||||
task_id: int,
|
||||
user_id: str,
|
||||
):
|
||||
async with AsyncSession(async_engine) as async_session:
|
||||
try:
|
||||
await async_session.execute(
|
||||
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
|
||||
)
|
||||
tasks_repository = TasksRepository(async_session)
|
||||
tasks_service = TasksService(tasks_repository)
|
||||
|
||||
await process_assistant(
|
||||
assistant_id,
|
||||
notification_uuid,
|
||||
task_id,
|
||||
tasks_service,
|
||||
user_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await async_session.rollback()
|
||||
raise e
|
||||
finally:
|
||||
await async_session.close()
|
||||
|
||||
|
||||
@celery.task(
|
||||
retries=3,
|
||||
default_retry_delay=1,
|
||||
@ -111,10 +171,6 @@ def process_file_task(
|
||||
if async_engine is None:
|
||||
init_worker()
|
||||
|
||||
logger.info(
|
||||
f"Task process_file started for file_name={file_name}, knowledge_id={knowledge_id}, brain_id={brain_id}, notification_id={notification_id}"
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(
|
||||
aprocess_file_task(
|
||||
|
@ -9,7 +9,7 @@ from uuid import UUID
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_core.files.file import FileExtension, QuivrFile
|
||||
|
||||
from quivr_worker.utils import get_tmp_name
|
||||
from quivr_worker.utils.utils import get_tmp_name
|
||||
|
||||
logger = get_logger("celery_worker")
|
||||
|
||||
|
0
backend/worker/quivr_worker/utils/__init__.py
Normal file
0
backend/worker/quivr_worker/utils/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user