feat(assistants): mock api (#3195)

# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] Ideally, I have added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
Stan Girard 2024-09-18 12:30:48 +02:00 committed by GitHub
parent 4390d318a2
commit 282fa0e3f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
109 changed files with 923 additions and 883 deletions

1
.gitignore vendored
View File

@ -102,3 +102,4 @@ backend/core/examples/chatbot/.chainlit/translations/en-US.json
# Tox
.tox
Pipfile
*.pkl

View File

@ -19,7 +19,6 @@ dependencies = [
"pydantic-settings>=2.4.0",
"python-dotenv>=1.0.1",
"unidecode>=1.3.8",
"fpdf>=1.7.2",
"colorlog>=6.8.2",
"posthog>=3.5.0",
"pyinstrument>=4.7.2",

View File

@ -4,7 +4,7 @@ from logging.handlers import RotatingFileHandler
from colorlog import (
ColoredFormatter,
) # You need to install this package: pip install colorlog
)
def get_logger(logger_name, log_file="application.log"):

View File

@ -4,6 +4,7 @@ from typing import Optional
from jose import jwt
from jose.exceptions import JWTError
from quivr_api.modules.user.entity.user_identity import UserIdentity
SECRET_KEY = os.environ.get("JWT_SECRET_KEY")

View File

@ -1,6 +1,7 @@
from uuid import UUID
from pydantic import BaseModel, ConfigDict
from quivr_api.logger import get_logger
logger = get_logger(__name__)

View File

@ -1,6 +1,7 @@
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
from quivr_api.modules.analytics.entity.analytics import Range
from quivr_api.modules.analytics.service.analytics_service import AnalyticsService

View File

@ -1,16 +1,20 @@
from datetime import date
from enum import IntEnum
from typing import List
from pydantic import BaseModel
from datetime import date
class Range(IntEnum):
WEEK = 7
MONTH = 30
QUARTER = 90
class Usage(BaseModel):
date: date
usage_count: int
class BrainsUsages(BaseModel):
usages: List[Usage]
usages: List[Usage]

View File

@ -11,5 +11,4 @@ class AnalyticsService:
self.repository = Analytics()
def get_brains_usages(self, user_id, graph_range, brain_id=None):
return self.repository.get_brains_usages(user_id, graph_range, brain_id)

View File

@ -3,6 +3,7 @@ from typing import List
from uuid import uuid4
from fastapi import APIRouter, Depends
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.api_key.dto.outputs import ApiKeyInfo

View File

@ -1,6 +1,7 @@
from datetime import datetime
from fastapi import HTTPException
from quivr_api.logger import get_logger
from quivr_api.modules.api_key.repository.api_key_interface import ApiKeysInterface
from quivr_api.modules.api_key.repository.api_keys import ApiKeys

View File

@ -1 +1,6 @@
# noqa:
from .assistant_routes import assistant_router
__all__ = [
"assistant_router",
]

View File

@ -1,63 +1,176 @@
from typing import List
import io
from typing import Annotated, List
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.assistant.dto.inputs import InputAssistant
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
from quivr_api.modules.assistant.controller.assistants_definition import (
assistants,
validate_assistant_input,
)
from quivr_api.modules.assistant.dto.inputs import CreateTask, InputAssistant
from quivr_api.modules.assistant.dto.outputs import AssistantOutput
from quivr_api.modules.assistant.ito.difference import DifferenceAssistant
from quivr_api.modules.assistant.ito.summary import SummaryAssistant, summary_inputs
from quivr_api.modules.assistant.service.assistant import Assistant
from quivr_api.modules.assistant.entity.assistant_entity import (
AssistantSettings,
)
from quivr_api.modules.assistant.services.tasks_service import TasksService
from quivr_api.modules.dependencies import get_service
from quivr_api.modules.upload.service.upload_file import (
upload_file_storage,
)
from quivr_api.modules.user.entity.user_identity import UserIdentity
assistant_router = APIRouter()
logger = get_logger(__name__)
assistant_service = Assistant()
assistant_router = APIRouter()
TasksServiceDep = Annotated[TasksService, Depends(get_service(TasksService))]
UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
@assistant_router.get(
    "/assistants", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
)
async def get_assistants(
    request: Request,
    current_user: UserIdentity = Depends(get_current_user),
) -> List[AssistantOutput]:
    """
    List the static assistant definitions available to the caller.
    """
    # Fixed stale docstring (it described brain knowledge listing) and removed
    # leftover pre-rename code (`summary_inputs()` based response).
    logger.info("Getting assistants")
    return assistants
@assistant_router.get(
    "/assistants/tasks", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
)
async def get_tasks(
    request: Request,
    current_user: UserIdentityDep,
    tasks_service: TasksServiceDep,
):
    """Return every assistant task belonging to the authenticated user."""
    logger.info("Getting tasks")
    return await tasks_service.get_tasks_by_user_id(current_user.id)
@assistant_router.post(
    "/assistants/task", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
)
async def create_task(
    current_user: UserIdentityDep,
    tasks_service: TasksServiceDep,
    request: Request,
    input: InputAssistant,
    files: List[UploadFile] = None,
):
    """Validate the input, upload any files, persist a Task row and enqueue it.

    Raises:
        HTTPException 404: unknown assistant id.
        HTTPException 400: input fails the assistant's declared requirements.
        HTTPException 500: file upload to storage failed.
    """
    assistant = next(
        (assistant for assistant in assistants if assistant.id == input.id), None
    )
    if assistant is None:
        raise HTTPException(status_code=404, detail="Assistant not found")

    is_valid, validation_errors = validate_assistant_input(input, assistant)
    if not is_valid:
        # Fixed: report all validation errors at once instead of printing each
        # and raising on the loop's first iteration (and dropped debug prints).
        logger.info("Invalid assistant input: %s", validation_errors)
        raise HTTPException(status_code=400, detail=validation_errors)

    notification_uuid = uuid4()

    # Fixed: `files` defaults to None, so guard before iterating — the
    # original crashed with TypeError when no files were sent.
    for upload_file in files or []:
        file_name_path = f"{input.id}/{notification_uuid}/{upload_file.filename}"
        buff_reader = io.BufferedReader(upload_file.file)  # type: ignore
        try:
            await upload_file_storage(buff_reader, file_name_path)
        except Exception as e:
            logger.exception(f"Exception in upload_route {e}")
            raise HTTPException(
                status_code=500, detail=f"Failed to upload file to storage. {e}"
            )

    task = CreateTask(
        assistant_id=input.id,
        pretty_id=str(notification_uuid),
        settings=input.model_dump(mode="json"),
    )
    task_created = await tasks_service.create_task(task, current_user.id)

    # Hand the heavy processing off to the worker.
    celery.send_task(
        "process_assistant_task",
        kwargs={
            "assistant_id": input.id,
            "notification_uuid": notification_uuid,
            "task_id": task_created.id,
            "user_id": str(current_user.id),
        },
    )
    return task_created
@assistant_router.get(
    "/assistants/task/{task_id}",
    dependencies=[Depends(AuthBearer())],
    tags=["Assistant"],
)
async def get_task(
    request: Request,
    task_id: str,
    current_user: UserIdentityDep,
    tasks_service: TasksServiceDep,
):
    """Fetch a single task owned by the current user.

    NOTE(review): task_id is typed ``str`` here but ``int`` in the sibling
    delete/download endpoints — confirm the intended path-parameter type.
    """
    # Removed leftover pre-rewrite `process_assistant` dispatch code that was
    # interleaved with this handler.
    return await tasks_service.get_task_by_id(task_id, current_user.id)  # type: ignore
@assistant_router.delete(
    "/assistants/task/{task_id}",
    dependencies=[Depends(AuthBearer())],
    tags=["Assistant"],
)
async def delete_task(
    request: Request,
    task_id: int,
    current_user: UserIdentityDep,
    tasks_service: TasksServiceDep,
):
    """Delete one of the current user's tasks by numeric id."""
    return await tasks_service.delete_task(task_id, current_user.id)
@assistant_router.get(
    "/assistants/task/{task_id}/download",
    dependencies=[Depends(AuthBearer())],
    tags=["Assistant"],
)
async def get_download_link_task(
    request: Request,
    task_id: int,
    current_user: UserIdentityDep,
    tasks_service: TasksServiceDep,
):
    """Return a download link for the task's output file, scoped to the owner."""
    return await tasks_service.get_download_link_task(task_id, current_user.id)
@assistant_router.get(
    "/assistants/{assistant_id}/config",
    dependencies=[Depends(AuthBearer())],
    tags=["Assistant"],
    response_model=AssistantSettings,
    summary="Retrieve assistant configuration",
    description="Get the settings and file requirements for the specified assistant.",
)
async def get_assistant_config(
    assistant_id: int,
    current_user: UserIdentityDep,
):
    """Look up the assistant by id in the static catalogue and return its settings."""
    assistant = next(
        (assistant for assistant in assistants if assistant.id == assistant_id), None
    )
    if assistant is None:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return assistant.settings

View File

@ -0,0 +1,201 @@
from quivr_api.modules.assistant.dto.inputs import InputAssistant
from quivr_api.modules.assistant.dto.outputs import (
AssistantOutput,
InputFile,
Inputs,
Pricing,
)
def validate_assistant_input(
    assistant_input: InputAssistant, assistant_output: AssistantOutput
):
    """Validate a caller-supplied assistant input against an assistant definition.

    Checks that every input category the assistant declares as required
    (files, URLs, texts, booleans, numbers, selects, brain) is present, and
    applies per-field constraints: regex for texts, min/max for numbers,
    allowed options for selects.

    Returns:
        ``(True, None)`` when valid, otherwise ``(False, [error, ...])``.
    """
    errors = []

    def _provided_keys(items):
        # Keys the caller actually supplied for one input category.
        return {item.key for item in (items or [])}

    def _provided_value(items, key):
        # Value supplied for *key*, or None when the key is absent.
        return next((item.value for item in (items or []) if item.key == key), None)

    # Files: presence only.
    if assistant_output.inputs.files:
        provided = _provided_keys(assistant_input.inputs.files)
        for req_file in assistant_output.inputs.files:
            if req_file.required and req_file.key not in provided:
                errors.append(f"Missing required file input: {req_file.key}")

    # URLs: presence only.
    if assistant_output.inputs.urls:
        provided = _provided_keys(assistant_input.inputs.urls)
        for req_url in assistant_output.inputs.urls:
            if req_url.required and req_url.key not in provided:
                errors.append(f"Missing required URL input: {req_url.key}")

    # Texts: presence, then optional regex format check.
    if assistant_output.inputs.texts:
        provided = _provided_keys(assistant_input.inputs.texts)
        for req_text in assistant_output.inputs.texts:
            if not req_text.required:
                continue
            if req_text.key not in provided:
                errors.append(f"Missing required text input: {req_text.key}")
            elif req_text.validation_regex:
                # Bugfix: the original re-looked up the very declaration it
                # was iterating over; req_text already carries the regex.
                # (`import re` also hoisted out of the loop to module level.)
                value = (
                    _provided_value(assistant_input.inputs.texts, req_text.key) or ""
                )
                if not re.match(req_text.validation_regex, value):
                    errors.append(
                        f"Text input '{req_text.key}' does not match the required format."
                    )

    # Booleans: presence only.
    if assistant_output.inputs.booleans:
        provided = _provided_keys(assistant_input.inputs.booleans)
        for req_bool in assistant_output.inputs.booleans:
            if req_bool.required and req_bool.key not in provided:
                errors.append(f"Missing required boolean input: {req_bool.key}")

    # Numbers: presence, then min/max bounds.
    if assistant_output.inputs.numbers:
        provided = _provided_keys(assistant_input.inputs.numbers)
        for req_number in assistant_output.inputs.numbers:
            if not req_number.required:
                continue
            if req_number.key not in provided:
                errors.append(f"Missing required number input: {req_number.key}")
            else:
                value = _provided_value(
                    assistant_input.inputs.numbers, req_number.key
                )
                # Bugfix: guard against a None value before comparing — the
                # original raised TypeError on `None < min`.
                if value is not None:
                    if req_number.min is not None and value < req_number.min:
                        errors.append(
                            f"Number input '{req_number.key}' is below minimum value."
                        )
                    if req_number.max is not None and value > req_number.max:
                        errors.append(
                            f"Number input '{req_number.key}' exceeds maximum value."
                        )

    # Select texts: presence, then membership in the declared options.
    if assistant_output.inputs.select_texts:
        provided = _provided_keys(assistant_input.inputs.select_texts)
        for req_select in assistant_output.inputs.select_texts:
            if not req_select.required:
                continue
            if req_select.key not in provided:
                errors.append(f"Missing required select text input: {req_select.key}")
            else:
                value = _provided_value(
                    assistant_input.inputs.select_texts, req_select.key
                )
                # Bugfix: skip the membership test when no options are
                # declared — the original crashed on `x in None`.
                if req_select.options and value not in req_select.options:
                    errors.append(f"Invalid option for select text '{req_select.key}'.")

    # Select numbers: presence, then membership in the declared options.
    if assistant_output.inputs.select_numbers:
        provided = _provided_keys(assistant_input.inputs.select_numbers)
        for req_select in assistant_output.inputs.select_numbers:
            if not req_select.required:
                continue
            if req_select.key not in provided:
                errors.append(
                    f"Missing required select number input: {req_select.key}"
                )
            else:
                value = _provided_value(
                    assistant_input.inputs.select_numbers, req_select.key
                )
                if req_select.options and value not in req_select.options:
                    errors.append(
                        f"Invalid option for select number '{req_select.key}'."
                    )

    # Brain: "required" means both present and carrying a non-empty value.
    if assistant_output.inputs.brain and assistant_output.inputs.brain.required:
        if not assistant_input.inputs.brain or not assistant_input.inputs.brain.value:
            errors.append("Missing required brain input.")

    if errors:
        return False, errors
    return True, None
# Static assistant catalogue served by the mock API. The placeholder names
# and descriptions indicate these are mock definitions; ids should stay
# stable because clients and tasks reference them.
assistant1 = AssistantOutput(
    id=1,
    name="Assistant 1",
    description="Assistant 1 description",
    pricing=Pricing(),
    tags=["tag1", "tag2"],
    input_description="Input description",
    output_description="Output description",
    inputs=Inputs(
        files=[
            InputFile(key="file_1", description="File description"),
            InputFile(key="file_2", description="File description"),
        ],
    ),
    icon_url="https://example.com/icon.png",
)

assistant2 = AssistantOutput(
    id=2,
    name="Assistant 2",
    description="Assistant 2 description",
    pricing=Pricing(),
    tags=["tag1", "tag2"],
    input_description="Input description",
    output_description="Output description",
    icon_url="https://example.com/icon.png",
    inputs=Inputs(
        files=[
            InputFile(key="file_1", description="File description"),
            InputFile(key="file_2", description="File description"),
        ],
    ),
)

# Catalogue consumed by the /assistants routes and input validation.
assistants = [assistant1, assistant2]

View File

@ -1,16 +1,16 @@
import json
from typing import List, Optional
from uuid import UUID
from pydantic import BaseModel, model_validator, root_validator
from pydantic import BaseModel, root_validator
class EmailInput(BaseModel):
activated: bool
class CreateTask(BaseModel):
pretty_id: str
assistant_id: int
settings: dict
class BrainInput(BaseModel):
activated: Optional[bool] = False
value: Optional[UUID] = None
@root_validator(pre=True)
@ -64,19 +64,10 @@ class Inputs(BaseModel):
numbers: Optional[List[InputNumber]] = None
select_texts: Optional[List[InputSelectText]] = None
select_numbers: Optional[List[InputSelectNumber]] = None
class Outputs(BaseModel):
email: Optional[EmailInput] = None
brain: Optional[BrainInput] = None
class InputAssistant(BaseModel):
id: int
name: str
inputs: Inputs
outputs: Outputs
@model_validator(mode="before")
@classmethod
def to_py_dict(cls, data):
return json.loads(data)

View File

@ -3,6 +3,12 @@ from typing import List, Optional
from pydantic import BaseModel
class BrainInput(BaseModel):
required: Optional[bool] = True
description: str
type: str
class InputFile(BaseModel):
key: str
allowed_extensions: Optional[List[str]] = ["pdf"]
@ -63,23 +69,7 @@ class Inputs(BaseModel):
numbers: Optional[List[InputNumber]] = None
select_texts: Optional[List[InputSelectText]] = None
select_numbers: Optional[List[InputSelectNumber]] = None
class OutputEmail(BaseModel):
required: Optional[bool] = True
description: str
type: str
class OutputBrain(BaseModel):
required: Optional[bool] = True
description: str
type: str
class Outputs(BaseModel):
email: Optional[OutputEmail] = None
brain: Optional[OutputBrain] = None
brain: Optional[BrainInput] = None
class Pricing(BaseModel):
@ -88,6 +78,7 @@ class Pricing(BaseModel):
class AssistantOutput(BaseModel):
id: int
name: str
description: str
pricing: Optional[Pricing] = Pricing()
@ -95,5 +86,4 @@ class AssistantOutput(BaseModel):
input_description: str
output_description: str
inputs: Inputs
outputs: Outputs
icon_url: Optional[str] = None

View File

@ -1 +0,0 @@
from .assistant import AssistantEntity

View File

@ -1,11 +0,0 @@
from uuid import UUID
from pydantic import BaseModel
class AssistantEntity(BaseModel):
id: UUID
name: str
brain_id_required: bool
file_1_required: bool
url_required: bool

View File

@ -0,0 +1,33 @@
from typing import Any, List, Optional
from pydantic import BaseModel
class AssistantFileRequirement(BaseModel):
    # Declares one file an assistant expects from the caller.
    name: str
    description: Optional[str] = None
    required: bool = True  # must be provided unless explicitly optional
    accepted_types: Optional[List[str]] = None  # e.g., ['text/csv', 'application/json']
class AssistantInput(BaseModel):
    # Declares one non-file input an assistant expects.
    name: str
    description: str
    type: str  # e.g., 'boolean', 'uuid', 'string'
    required: bool = True
    regex: Optional[str] = None  # optional validation pattern for the value
    options: Optional[List[Any]] = None  # For predefined choices
    default: Optional[Any] = None
class AssistantSettings(BaseModel):
    # Full input contract of an assistant: scalar inputs plus file requirements.
    inputs: List[AssistantInput]
    files: Optional[List[AssistantFileRequirement]] = None
class Assistant(BaseModel):
    # Top-level assistant descriptor exposed by the API.
    id: int
    name: str
    description: str
    settings: AssistantSettings
    required_files: Optional[List[str]] = None  # List of required file names

View File

@ -0,0 +1,34 @@
from datetime import datetime
from typing import Dict
from uuid import UUID
from sqlmodel import JSON, TIMESTAMP, BigInteger, Column, Field, SQLModel, text
class Task(SQLModel, table=True):
    """Persisted assistant task row (table ``tasks``)."""

    __tablename__ = "tasks"  # type: ignore

    # Auto-incrementing BigInteger primary key assigned by the database.
    id: int | None = Field(
        default=None,
        sa_column=Column(
            BigInteger,
            primary_key=True,
            autoincrement=True,
        ),
    )
    assistant_id: int
    pretty_id: str
    user_id: UUID
    # Lifecycle state; new rows start as "pending".
    status: str = Field(default="pending")
    # Set by the database at insert time (CURRENT_TIMESTAMP, naive timestamp).
    creation_time: datetime | None = Field(
        default=None,
        sa_column=Column(
            TIMESTAMP(timezone=False),
            server_default=text("CURRENT_TIMESTAMP"),
        ),
    )
    # Snapshot of the assistant input, stored as a JSON column.
    settings: Dict = Field(default_factory=dict, sa_column=Column(JSON))
    answer: str | None = Field(default=None)

    class Config:
        arbitrary_types_allowed = True

View File

@ -1,73 +0,0 @@
from bs4 import BeautifulSoup as Soup
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from quivr_api.logger import get_logger
from quivr_api.modules.assistant.dto.outputs import (
AssistantOutput,
Inputs,
InputUrl,
OutputBrain,
OutputEmail,
Outputs,
)
from quivr_api.modules.assistant.ito.ito import ITO
logger = get_logger(__name__)
class CrawlerAssistant(ITO):
    """Assistant that crawls a website and uploads the extracted page text."""

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )

    async def process_assistant(self):
        # self.url is expected to come from the assistant inputs via the ITO
        # base — TODO confirm against the base class.
        url = self.url
        loader = RecursiveUrlLoader(
            url=url, max_depth=2, extractor=lambda x: Soup(x, "html.parser").text
        )
        docs = loader.load()

        # Derive an upload-safe filename from the URL,
        # e.g. "https://a.b/c" -> "a_b_c.txt".
        nice_url = url.split("://")[1].replace("/", "_").replace(".", "_")
        nice_url += ".txt"

        # Fixed: the original wrote `for docs in docs:`, shadowing the list
        # with each element — it worked by accident but was error-prone.
        for doc in docs:
            await self.create_and_upload_processed_file(
                doc.page_content, nice_url, "Crawler"
            )
def crawler_inputs():
    """Build the static AssistantOutput definition for the crawler assistant."""
    url_input = InputUrl(
        key="url",
        required=True,
        description="The URL to crawl",
    )
    brain_output = OutputBrain(
        required=True,
        description="The brain to which upload the document",
        type="uuid",
    )
    email_output = OutputEmail(
        required=True,
        description="Send the document by email",
        type="str",
    )
    return AssistantOutput(
        name="Crawler",
        description="Crawls a website and extracts the text from the pages",
        tags=["new"],
        input_description="One URL to crawl",
        output_description="Text extracted from the pages",
        inputs=Inputs(urls=[url_input]),
        outputs=Outputs(brain=brain_output, email=email_output),
    )

View File

@ -1,171 +0,0 @@
import os
import tempfile
from typing import List
from fastapi import UploadFile
from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from llama_parse import LlamaParse
from quivr_api.logger import get_logger
from quivr_api.modules.assistant.dto.inputs import InputAssistant
from quivr_api.modules.assistant.dto.outputs import (
AssistantOutput,
InputFile,
Inputs,
OutputBrain,
OutputEmail,
Outputs,
)
from quivr_api.modules.assistant.ito.ito import ITO
from quivr_api.modules.user.entity.user_identity import UserIdentity
logger = get_logger(__name__)
class DifferenceAssistant(ITO):
    """Assistant that reports the differences between two uploaded documents."""

    def __init__(
        self,
        input: InputAssistant,
        files: List[UploadFile] = None,
        current_user: UserIdentity = None,
        **kwargs,
    ):
        super().__init__(
            input=input,
            files=files,
            current_user=current_user,
            **kwargs,
        )

    def check_input(self):
        """Validate that exactly two files (keys doc_1/doc_2) and at least one
        output (brain or email) were provided.

        Raises:
            ValueError: on the first requirement that is not met.
        """
        if not self.files:
            raise ValueError("No file was uploaded")
        if len(self.files) != 2:
            raise ValueError("Only two files can be uploaded")
        if not self.input.inputs.files:
            raise ValueError("No files key were given in the input")
        if len(self.input.inputs.files) != 2:
            raise ValueError("Only two files can be uploaded")
        if not self.input.inputs.files[0].key == "doc_1":
            raise ValueError("The key of the first file should be doc_1")
        if not self.input.inputs.files[1].key == "doc_2":
            raise ValueError("The key of the second file should be doc_2")
        if not self.input.inputs.files[0].value:
            raise ValueError("No file was uploaded")
        if not (
            self.input.outputs.brain.activated or self.input.outputs.email.activated
        ):
            raise ValueError("No output was selected")
        return True

    async def process_assistant(self):
        """Parse both documents with LlamaParse and ask the LLM for their diff."""
        document_1 = self.files[0]
        document_2 = self.files[1]

        # Keep the original extensions so the parser detects the format.
        document_1_ext = os.path.splitext(document_1.filename)[1]
        document_2_ext = os.path.splitext(document_2.filename)[1]

        document_1_tmp = tempfile.NamedTemporaryFile(
            suffix=document_1_ext, delete=False
        )
        document_2_tmp = tempfile.NamedTemporaryFile(
            suffix=document_2_ext, delete=False
        )
        try:
            document_1_tmp.write(document_1.file.read())
            document_2_tmp.write(document_2.file.read())
            # Fixed: flush before handing the paths to the parser — the
            # original parsed potentially un-flushed files.
            document_1_tmp.flush()
            document_2_tmp.flush()

            parser = LlamaParse(
                result_type="markdown"  # "markdown" and "text" are available
            )
            document_1_llama_parsed = parser.load_data(document_1_tmp.name)
            document_2_llama_parsed = parser.load_data(document_2_tmp.name)
        finally:
            # Fixed: delete=False temp files were previously closed but never
            # removed, leaking two files on disk per run.
            document_1_tmp.close()
            document_2_tmp.close()
            os.unlink(document_1_tmp.name)
            os.unlink(document_2_tmp.name)

        document_1_to_langchain = document_1_llama_parsed[0].to_langchain_format()
        document_2_to_langchain = document_2_llama_parsed[0].to_langchain_format()

        llm = ChatLiteLLM(model="gpt-4o")

        human_prompt = """Given the following two documents, find the difference between them:
Document 1:
{document_1}
Document 2:
{document_2}
Difference:
"""
        # Removed the unused CONDENSE_QUESTION_PROMPT local from the original.
        system_message_template = """
You are an expert in finding the difference between two documents. You look deeply into what makes the two documents different and provide a detailed analysis if needed of the differences between the two documents.
If no differences are found, simply say that there are no differences.
"""
        ANSWER_PROMPT = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(system_message_template),
                HumanMessagePromptTemplate.from_template(human_prompt),
            ]
        )

        final_inputs = {
            "document_1": document_1_to_langchain.page_content,
            "document_2": document_2_to_langchain.page_content,
        }

        output_parser = StrOutputParser()
        chain = ANSWER_PROMPT | llm | output_parser
        result = chain.invoke(final_inputs)
        return result
def difference_inputs():
    """Build the static AssistantOutput definition for the difference assistant."""
    doc_1 = InputFile(
        key="doc_1",
        allowed_extensions=["pdf"],
        required=True,
        description="The first document to compare",
    )
    doc_2 = InputFile(
        key="doc_2",
        allowed_extensions=["pdf"],
        required=True,
        description="The second document to compare",
    )
    brain_output = OutputBrain(
        required=True,
        description="The brain to which upload the document",
        type="uuid",
    )
    email_output = OutputEmail(
        required=True,
        description="Send the document by email",
        type="str",
    )
    return AssistantOutput(
        name="difference",
        description="Finds the difference between two sets of documents",
        tags=["new"],
        input_description="Two documents to compare",
        output_description="The difference between the two documents",
        icon_url="https://quivr-cms.s3.eu-west-3.amazonaws.com/report_94bea8b918.png",
        inputs=Inputs(files=[doc_1, doc_2]),
        outputs=Outputs(brain=brain_output, email=email_output),
    )

View File

@ -1,196 +0,0 @@
import os
import random
import re
import string
from abc import abstractmethod
from io import BytesIO
from tempfile import NamedTemporaryFile
from typing import List, Optional
from fastapi import UploadFile
from pydantic import BaseModel
from unidecode import unidecode
from quivr_api.logger import get_logger
from quivr_api.models.settings import SendEmailSettings
from quivr_api.modules.assistant.dto.inputs import InputAssistant
from quivr_api.modules.assistant.ito.utils.pdf_generator import PDFGenerator, PDFModel
from quivr_api.modules.chat.controller.chat.utils import update_user_usage
from quivr_api.modules.upload.controller.upload_routes import upload_file
from quivr_api.modules.user.entity.user_identity import UserIdentity
from quivr_api.modules.user.service.user_usage import UserUsage
from quivr_api.utils.send_email import send_email
logger = get_logger(__name__)
class ITO(BaseModel):
    """Base class for assistants ("Input/Task/Output" pipeline).

    Holds the validated assistant input, the uploaded files and the calling
    user; charges the user's usage quota at construction time; and provides
    shared helpers for producing the processed output as a PDF and delivering
    it by email and/or brain upload.
    """

    input: InputAssistant
    files: List[UploadFile]
    current_user: UserIdentity
    # Populated by __init__ from current_user; Optional so pydantic accepts
    # model construction before they are assigned.
    user_usage: Optional[UserUsage] = None
    user_settings: Optional[dict] = None

    def __init__(
        self,
        input: InputAssistant,
        files: List[UploadFile] = None,
        current_user: UserIdentity = None,
        **kwargs,
    ):
        super().__init__(
            input=input,
            files=files,
            current_user=current_user,
            **kwargs,
        )
        self.user_usage = UserUsage(
            id=current_user.id,
            email=current_user.email,
        )
        self.user_settings = self.user_usage.get_user_settings()
        # NOTE(review): the user is charged immediately on construction,
        # before any processing runs — confirm this is intended.
        self.increase_usage_user()

    def increase_usage_user(self):
        # Raises an error if the user has consumed all of of his credits
        update_user_usage(
            usage=self.user_usage,
            user_settings=self.user_settings,
            cost=self.calculate_pricing(),
        )

    def calculate_pricing(self):
        # Flat per-run cost; subclasses may override.
        return 20

    def generate_pdf(self, filename: str, title: str, content: str):
        # Render *content* under *title* into a PDF written to *filename*.
        pdf_model = PDFModel(title=title, content=content)
        pdf = PDFGenerator(pdf_model)
        pdf.print_pdf()
        pdf.output(filename, "F")

    @abstractmethod
    async def process_assistant(self):
        # Subclasses implement the actual assistant work here.
        pass

    async def send_output_by_email(
        self,
        file: UploadFile,
        filename: str,
        task_name: str,
        custom_message: str,
        brain_id: str = None,
    ):
        """Email the processed file as an attachment to the current user."""
        settings = SendEmailSettings()
        # Materialise the UploadFile to a real on-disk file for attachment.
        file = await self.uploadfile_to_file(file)
        domain_quivr = os.getenv("QUIVR_DOMAIN", "https://chat.quivr.app/")
        with open(file.name, "rb") as f:
            mail_from = settings.resend_contact_sales_from
            mail_to = self.current_user.email
            body = f"""
<div style="text-align: center;">
<img src="https://quivr-cms.s3.eu-west-3.amazonaws.com/logo_quivr_white_7e3c72620f.png" alt="Quivr Logo" style="width: 100px; height: 100px; border-radius: 50%; margin: 0 auto; display: block;">
<p>Quivr's ingestion process has been completed. The processed file is attached.</p>
<p><strong>Task:</strong> {task_name}</p>
<p><strong>Output:</strong> {custom_message}</p>
<br />
</div>
            """
            # Link to the brain only when the document was also uploaded there.
            if brain_id:
                body += f"<div style='text-align: center;'>You can find the file <a href='{domain_quivr}studio/{brain_id}'>here</a>.</div> <br />"
            body += """
<div style="text-align: center;">
<p>Please let us know if you have any questions or need further assistance.</p>
<p> The Quivr Team </p>
</div>
            """
            params = {
                "from": mail_from,
                "to": [mail_to],
                "subject": "Quivr Ingestion Processed",
                "reply_to": "no-reply@quivr.app",
                "html": body,
                "attachments": [{"filename": filename, "content": list(f.read())}],
            }
            logger.info(f"Sending email to {mail_to} with file (unknown)")
            send_email(params)

    async def uploadfile_to_file(self, uploadFile: UploadFile):
        # Transform the UploadFile object to a file object with same name and content
        # NOTE(review): delete=False — the caller is responsible for cleanup.
        tmp_file = NamedTemporaryFile(delete=False)
        tmp_file.write(uploadFile.file.read())
        tmp_file.flush()  # Make sure all data is written to disk
        return tmp_file

    async def create_and_upload_processed_file(
        self, processed_content: str, original_filename: str, file_description: str
    ) -> dict:
        """Handles creation and uploading of the processed file."""
        # remove any special characters from the filename that aren't http safe
        new_filename = (
            original_filename.split(".")[0]
            + "_"
            + file_description.lower().replace(" ", "_")
            + "_"
            + str(random.randint(1000, 9999))
            + ".pdf"
        )
        new_filename = unidecode(new_filename)
        new_filename = re.sub(
            "[^{}0-9a-zA-Z]".format(re.escape(string.punctuation)), "", new_filename
        )
        # Render the processed content into a local PDF file.
        self.generate_pdf(
            new_filename,
            f"{file_description} of {original_filename}",
            processed_content,
        )
        # Load the generated PDF into memory so it can be sent/uploaded.
        content_io = BytesIO()
        with open(new_filename, "rb") as f:
            content_io.write(f.read())
        content_io.seek(0)
        file_to_upload = UploadFile(
            filename=new_filename,
            file=content_io,
            headers={"content-type": "application/pdf"},
        )
        if self.input.outputs.email.activated:
            await self.send_output_by_email(
                file_to_upload,
                new_filename,
                "Summary",
                f"{file_description} of {original_filename}",
                brain_id=(
                    self.input.outputs.brain.value
                    if (
                        self.input.outputs.brain.activated
                        and self.input.outputs.brain.value
                    )
                    else None
                ),
            )
        # Reset to start of file before upload
        file_to_upload.file.seek(0)
        if self.input.outputs.brain.activated:
            await upload_file(
                uploadFile=file_to_upload,
                brain_id=self.input.outputs.brain.value,
                current_user=self.current_user,
                chat_id=None,
            )
        # Remove the local PDF once it has been delivered.
        os.remove(new_filename)
        return {"message": f"{file_description} generated successfully"}

View File

@ -1,196 +0,0 @@
import tempfile
from typing import List
from fastapi import UploadFile
from langchain.chains import (
MapReduceDocumentsChain,
ReduceDocumentsChain,
StuffDocumentsChain,
)
from langchain.chains.llm import LLMChain
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_core.prompts import PromptTemplate
from langchain_text_splitters import CharacterTextSplitter
from quivr_api.logger import get_logger
from quivr_api.modules.assistant.dto.inputs import InputAssistant
from quivr_api.modules.assistant.dto.outputs import (
AssistantOutput,
InputFile,
Inputs,
OutputBrain,
OutputEmail,
Outputs,
)
from quivr_api.modules.assistant.ito.ito import ITO
from quivr_api.modules.user.entity.user_identity import UserIdentity
logger = get_logger(__name__)
class SummaryAssistant(ITO):
    """Assistant that summarizes a single uploaded PDF document.

    The PDF is split into token-bounded chunks, each chunk is analyzed
    individually (map step), and the partial analyses are distilled into one
    consolidated markdown summary (reduce step).  The result is uploaded to a
    brain and/or sent by email via ``create_and_upload_processed_file``.
    """

    def __init__(
        self,
        input: InputAssistant,
        files: List[UploadFile] = None,
        current_user: UserIdentity = None,
        **kwargs,
    ):
        super().__init__(
            input=input,
            files=files,
            current_user=current_user,
            **kwargs,
        )

    def check_input(self):
        """Validate that exactly one correctly-keyed file and an output were given.

        Returns:
            bool: True when the input is valid.

        Raises:
            ValueError: if no file (or more than one) was uploaded, the input
                key is not ``doc_to_summarize``, the declared file name does
                not match the uploaded file, or neither brain nor email output
                is activated.
        """
        if not self.files:
            raise ValueError("No file was uploaded")
        if len(self.files) > 1:
            raise ValueError("Only one file can be uploaded")
        if not self.input.inputs.files:
            raise ValueError("No files key were given in the input")
        if len(self.input.inputs.files) > 1:
            raise ValueError("Only one file can be uploaded")
        if not self.input.inputs.files[0].key == "doc_to_summarize":
            raise ValueError("The key of the file should be doc_to_summarize")
        if not self.input.inputs.files[0].value:
            raise ValueError("No file was uploaded")
        # The declared file value must match the name of the uploaded file.
        if not self.input.inputs.files[0].value == self.files[0].filename:
            raise ValueError(
                "The key of the file should be the same as the name of the file"
            )
        if not (
            self.input.outputs.brain.activated or self.input.outputs.email.activated
        ):
            raise ValueError("No output was selected")
        return True

    async def process_assistant(self):
        """Run the map-reduce summarization pipeline on the uploaded PDF.

        Returns:
            The result of ``create_and_upload_processed_file``, or an
            ``{"error": ...}`` dict when the usage quota update fails.
        """
        try:
            self.increase_usage_user()
        except Exception as e:
            logger.error(f"Error increasing usage: {e}")
            return {"error": str(e)}

        # Write the upload to a named temporary file so the PDF loader can
        # read it from disk.  delete=False is required so the closed file
        # survives until the loader has read it; the finally block removes it
        # afterwards (the previous version leaked one temp file per request).
        tmp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            tmp_file.write(self.files[0].file.read())
            tmp_file.close()
            loader = UnstructuredPDFLoader(tmp_file.name)
            data = loader.load()
        finally:
            os.unlink(tmp_file.name)

        llm = ChatLiteLLM(model="gpt-4o", max_tokens=2000)

        # Map step: analyze each chunk independently.
        map_template = """The following is a document that has been divided into multiple sections:
{docs}
Please carefully analyze each section and identify the following:
1. Main Themes: What are the overarching ideas or topics in this section?
2. Key Points: What are the most important facts, arguments, or ideas presented in this section?
3. Important Information: Are there any crucial details that stand out? This could include data, quotes, specific events, entity, or other relevant information.
4. People: Who are the key individuals mentioned in this section? What roles do they play?
5. Reasoning: What logic or arguments are used to support the key points?
6. Chapters: If the document is divided into chapters, what is the main focus of each chapter?
Remember to consider the language and context of the document. This will help in understanding the nuances and subtleties of the text."""
        map_prompt = PromptTemplate.from_template(map_template)
        map_chain = LLMChain(llm=llm, prompt=map_prompt)

        # Reduce step: consolidate the per-chunk analyses into one summary.
        reduce_template = """The following is a set of summaries for parts of the document:
{docs}
Take these and distill it into a final, consolidated summary of the document. Make sure to include the main themes, key points, and important information such as data, quotes,people and specific events.
Use markdown such as bold, italics, underlined. For example, **bold**, *italics*, and _underlined_ to highlight key points.
Please provide the final summary with sections using bold headers.
Sections should always be Summary and Key Points, but feel free to add more sections as needed.
Always use bold text for the sections headers.
Keep the same language as the documents.
Answer:"""
        reduce_prompt = PromptTemplate.from_template(reduce_template)
        reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)

        # Combines a list of documents into a single string for the reduce LLM.
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=reduce_chain, document_variable_name="docs"
        )

        # Combines and iteratively reduces the mapped documents.
        reduce_documents_chain = ReduceDocumentsChain(
            # This is the final chain that is called.
            combine_documents_chain=combine_documents_chain,
            # If documents exceed context for `StuffDocumentsChain`
            collapse_documents_chain=combine_documents_chain,
            # The maximum number of tokens to group documents into.
            token_max=4000,
        )

        # Full pipeline: map over chunks, then reduce to one summary.
        map_reduce_chain = MapReduceDocumentsChain(
            # Map chain
            llm_chain=map_chain,
            # Reduce chain
            reduce_documents_chain=reduce_documents_chain,
            # The variable name in the llm_chain to put the documents in
            document_variable_name="docs",
            # Return the results of the map steps in the output
            return_intermediate_steps=False,
        )

        text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=1000, chunk_overlap=100
        )
        split_docs = text_splitter.split_documents(data)

        content = map_reduce_chain.run(split_docs)

        return await self.create_and_upload_processed_file(
            content, self.files[0].filename, "Summary"
        )
def summary_inputs():
    """Describe the Summary assistant: metadata, accepted inputs, outputs."""
    pdf_input = InputFile(
        key="doc_to_summarize",
        allowed_extensions=["pdf"],
        required=True,
        description="The document to summarize",
    )
    brain_output = OutputBrain(
        required=True,
        description="The brain to which upload the document",
        type="uuid",
    )
    email_output = OutputEmail(
        required=True,
        description="Send the document by email",
        type="str",
    )
    return AssistantOutput(
        name="Summary",
        description="Summarize a set of documents",
        tags=["new"],
        input_description="One document to summarize",
        output_description="A summary of the document with key points and main themes",
        icon_url="https://quivr-cms.s3.eu-west-3.amazonaws.com/report_94bea8b918.png",
        inputs=Inputs(files=[pdf_input]),
        outputs=Outputs(brain=brain_output, email=email_output),
    )

Binary file not shown.

Before

Width:  |  Height:  |  Size: 23 KiB

View File

@ -1 +0,0 @@
from .assistant_interface import AssistantInterface

View File

@ -1,16 +0,0 @@
from abc import ABC, abstractmethod
from typing import List
from quivr_api.modules.assistant.entity.assistant import AssistantEntity
class AssistantInterface(ABC):
    """Repository contract for reading assistant definitions."""

    @abstractmethod
    def get_all_assistants(self) -> List[AssistantEntity]:
        """Return every assistant available in the store.

        Returns:
            List[AssistantEntity]: all registered assistants.
        """
        pass

View File

@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from typing import List
from uuid import UUID
from quivr_api.modules.assistant.dto.inputs import CreateTask
from quivr_api.modules.assistant.entity.task_entity import Task
class TasksInterface(ABC):
    """Service-level contract for CRUD operations on assistant tasks.

    NOTE(review): ``task_id`` is annotated ``UUID`` in some methods and
    ``int`` in others, and ``create_task`` here takes no ``user_id`` while
    implementations do — confirm the canonical signatures against the
    ``tasks`` table and the concrete service.
    """

    @abstractmethod
    def create_task(self, task: CreateTask) -> Task:
        """Persist a new task and return the created row."""
        pass

    @abstractmethod
    def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
        """Return the task with *task_id* owned by *user_id*."""
        pass

    @abstractmethod
    def delete_task(self, task_id: UUID, user_id: UUID) -> None:
        """Delete the task with *task_id* owned by *user_id*."""
        pass

    @abstractmethod
    def get_tasks_by_user_id(self, user_id: UUID) -> List[Task]:
        """Return all tasks owned by *user_id*."""
        pass

    @abstractmethod
    def update_task(self, task_id: int, task: dict) -> None:
        """Apply the field/value pairs in *task* to the task with *task_id*."""
        pass

    @abstractmethod
    def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
        """Return a signed download URL for the task's output file."""
        pass

View File

@ -0,0 +1,82 @@
from typing import Sequence
from uuid import UUID
from sqlalchemy import exc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from quivr_api.modules.assistant.dto.inputs import CreateTask
from quivr_api.modules.assistant.entity.task_entity import Task
from quivr_api.modules.dependencies import BaseRepository
from quivr_api.modules.upload.service.generate_file_signed_url import (
generate_file_signed_url,
)
class TasksRepository(BaseRepository):
    """Async SQLModel repository for assistant ``Task`` rows."""

    def __init__(self, session: AsyncSession):
        super().__init__(session)

    async def create_task(self, task: CreateTask, user_id: UUID) -> Task:
        """Insert a new task owned by *user_id* and return the refreshed row.

        Raises:
            Exception: when the insert violates an integrity constraint
                (the session is rolled back first).
        """
        try:
            task_to_create = Task(
                assistant_id=task.assistant_id,
                pretty_id=task.pretty_id,
                user_id=user_id,
                settings=task.settings,
            )
            self.session.add(task_to_create)
            await self.session.commit()
        except exc.IntegrityError:
            await self.session.rollback()
            raise Exception()
        # Refresh to pick up DB-generated fields (id, timestamps).
        await self.session.refresh(task_to_create)
        return task_to_create

    async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
        """Return the task with *task_id* owned by *user_id*.

        ``.one()`` raises if no row (or more than one) matches.
        """
        query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
        response = await self.session.exec(query)
        return response.one()

    async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]:
        """Return all tasks owned by *user_id* (possibly empty)."""
        query = select(Task).where(Task.user_id == user_id)
        response = await self.session.exec(query)
        return response.all()

    async def delete_task(self, task_id: int, user_id: UUID) -> None:
        """Delete the task with *task_id* owned by *user_id*."""
        query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
        response = await self.session.exec(query)
        # NOTE(review): .one() already raises when no row matches, so the
        # else branch below appears unreachable — confirm intended behavior.
        task = response.one()
        if task:
            await self.session.delete(task)
            await self.session.commit()
        else:
            raise Exception()

    async def update_task(self, task_id: int, task_updates: dict) -> None:
        """Apply the field/value pairs in *task_updates* to the task row."""
        query = select(Task).where(Task.id == task_id)
        response = await self.session.exec(query)
        # NOTE(review): .one() raises when no row matches, so the
        # "Task not found" branch appears unreachable — confirm.
        task = response.one()
        if task:
            for key, value in task_updates.items():
                setattr(task, key, value)
            await self.session.commit()
        else:
            raise Exception("Task not found")

    async def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
        """Return a signed URL for the task's ``output.pdf``.

        NOTE(review): returns the literal string "error" on any failure
        instead of raising — callers must check for that sentinel.
        """
        query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
        response = await self.session.exec(query)
        task = response.one()

        path = f"{task.assistant_id}/{task.pretty_id}/output.pdf"
        try:
            signed_url = generate_file_signed_url(path)
            if signed_url and "signedURL" in signed_url:
                return signed_url["signedURL"]
            else:
                raise Exception()
        except Exception:
            return "error"

View File

@ -1,32 +0,0 @@
from quivr_api.modules.assistant.entity.assistant import AssistantEntity
from quivr_api.modules.assistant.repository.assistant_interface import (
AssistantInterface,
)
from quivr_api.modules.dependencies import get_supabase_client
class Assistant(AssistantInterface):
    """Supabase-backed repository for assistant definitions."""

    def __init__(self):
        # Keep a handle to the shared Supabase client.
        self.db = get_supabase_client()

    def get_all_assistants(self):
        """Return every row of the ``assistants`` table ([] when none)."""
        result = self.db.from_("assistants").select("*").execute()
        return result.data if result.data else []

    def get_assistant_by_id(self, ingestion_id) -> AssistantEntity:
        """Return the assistant matching *ingestion_id*, or None if absent."""
        query = self.db.from_("assistants").select("*")
        result = query.filter("id", "eq", ingestion_id).execute()
        if not result.data:
            return None
        return AssistantEntity(**result.data[0])

View File

@ -0,0 +1,32 @@
from typing import Sequence
from uuid import UUID
from quivr_api.modules.assistant.dto.inputs import CreateTask
from quivr_api.modules.assistant.entity.task_entity import Task
from quivr_api.modules.assistant.repository.tasks import TasksRepository
from quivr_api.modules.dependencies import BaseService
class TasksService(BaseService[TasksRepository]):
    """Thin async service that delegates task CRUD to ``TasksRepository``."""

    repository_cls = TasksRepository

    def __init__(self, repository: TasksRepository):
        self.repository = repository

    async def create_task(self, task: CreateTask, user_id: UUID) -> Task:
        """Create a task owned by *user_id*."""
        return await self.repository.create_task(task, user_id)

    async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
        """Return the task with *task_id* owned by *user_id*."""
        return await self.repository.get_task_by_id(task_id, user_id)

    async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]:
        """Return all tasks owned by *user_id*."""
        return await self.repository.get_tasks_by_user_id(user_id)

    async def delete_task(self, task_id: int, user_id: UUID) -> None:
        """Delete the task with *task_id* owned by *user_id*."""
        return await self.repository.delete_task(task_id, user_id)

    async def update_task(self, task_id: int, task: dict) -> None:
        """Apply the field/value pairs in *task* to the task row."""
        return await self.repository.update_task(task_id, task)

    async def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
        """Return a signed download URL for the task's output file."""
        return await self.repository.get_download_link_task(task_id, user_id)

View File

@ -1 +1,5 @@
from .brain_routes import brain_router
__all__ = [
"brain_router",
]

View File

@ -2,6 +2,7 @@ from typing import Optional
from uuid import UUID
from pydantic import BaseModel
from quivr_api.logger import get_logger
from quivr_api.modules.brain.entity.brain_entity import BrainType
from quivr_api.modules.brain.entity.integration_brain import IntegrationType

View File

@ -11,6 +11,7 @@ from langchain_core.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain_core.prompts.prompt import PromptTemplate
from quivr_api.logger import get_logger
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.chat.dto.chats import ChatQuestion

View File

@ -4,6 +4,7 @@ from uuid import UUID
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.chat.dto.chats import ChatQuestion

View File

@ -10,6 +10,7 @@ from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation
from quivr_api.logger import get_logger
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.chat.dto.chats import ChatQuestion

View File

@ -4,6 +4,7 @@ from uuid import UUID
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from quivr_api.logger import get_logger
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.chat.dto.chats import ChatQuestion

View File

@ -7,6 +7,7 @@ from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from quivr_api.modules.brain.integrations.SQL.SQL_connector import SQLConnector
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.brain.repository.integration_brains import IntegrationBrain
@ -85,7 +86,6 @@ class SQLBrain(KnowledgeBrainQA, IntegrationBrain):
async def generate_stream(
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> AsyncIterable:
conversational_qa_chain = self.get_chain()
transformed_history, streamed_chat_history = (
self.initialize_streamed_chat_history(chat_id, question)

View File

@ -12,13 +12,14 @@ from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
from langchain_core.pydantic_v1 import Field as FieldV1
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict
from quivr_api.logger import get_logger
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from quivr_api.modules.chat.dto.chats import ChatQuestion
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
from quivr_api.modules.chat.service.chat_service import ChatService
from quivr_api.modules.dependencies import get_service
from typing_extensions import TypedDict
# Post-processing
@ -210,13 +211,11 @@ class SelfBrain(KnowledgeBrainQA):
return question_rewriter
def get_chain(self):
graph = self.create_graph()
return graph
def create_graph(self):
workflow = StateGraph(GraphState)
# Define the nodes

View File

@ -2,5 +2,7 @@
from .brains_interface import BrainsInterface
from .brains_users_interface import BrainsUsersInterface
from .brains_vectors_interface import BrainsVectorsInterface
from .integration_brains_interface import (IntegrationBrainInterface,
IntegrationDescriptionInterface)
from .integration_brains_interface import (
IntegrationBrainInterface,
IntegrationDescriptionInterface,
)

View File

@ -38,7 +38,6 @@ class IntegrationBrainInterface(ABC):
class IntegrationDescriptionInterface(ABC):
@abstractmethod
def get_integration_description(
self, integration_id: UUID

View File

@ -2,6 +2,7 @@ from typing import List, Optional, Union
from uuid import UUID
from fastapi import Depends, HTTPException, status
from quivr_api.middlewares.auth.auth_bearer import get_current_user
from quivr_api.modules.brain.entity.brain_entity import RoleEnum
from quivr_api.modules.brain.service.brain_service import BrainService
@ -13,7 +14,7 @@ brain_service = BrainService()
def has_brain_authorization(
required_roles: Optional[Union[RoleEnum, List[RoleEnum]]] = RoleEnum.Owner
required_roles: Optional[Union[RoleEnum, List[RoleEnum]]] = RoleEnum.Owner,
):
"""
Decorator to check if the user has the required role(s) for the brain

View File

@ -2,6 +2,7 @@ from typing import List
from uuid import UUID
from fastapi import HTTPException
from quivr_api.logger import get_logger
from quivr_api.modules.brain.entity.brain_entity import (
BrainEntity,

View File

@ -1,6 +1,7 @@
from typing import List, Tuple
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput

View File

@ -15,5 +15,7 @@ def get_prompt_to_use_id(
return (
prompt_id
if prompt_id
else brain_service.get_brain_prompt_id(brain_id) if brain_id else None
else brain_service.get_brain_prompt_id(brain_id)
if brain_id
else None
)

View File

@ -1,2 +0,0 @@
from fastapi import HTTPException
from quivr_api.modules.brain.dto.inputs import CreateBrainProperties

View File

@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
from uuid import UUID
from pydantic import BaseModel
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
from quivr_api.modules.notification.entity.notification import Notification

View File

@ -2,12 +2,12 @@ from datetime import datetime
from typing import List
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import JSON, TIMESTAMP, Column, Field, Relationship, SQLModel, text
from sqlmodel import UUID as PGUUID
from quivr_api.modules.brain.entity.brain_entity import Brain
from quivr_api.modules.user.entity.user_identity import User
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import JSON, TIMESTAMP
from sqlmodel import UUID as PGUUID
from sqlmodel import Column, Field, Relationship, SQLModel, text
class Chat(SQLModel, table=True):

View File

@ -1 +1 @@
from .knowledge_routes import knowledge_router
from .knowledge_routes import knowledge_router

View File

@ -85,3 +85,5 @@ class SupabaseS3Storage(StorageInterface):
return response
except Exception as e:
logger.error(e)
raise e

View File

@ -1,6 +1,7 @@
from typing import Annotated, List
from fastapi import APIRouter, Depends
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.dependencies import get_service

View File

@ -1 +1 @@
from .inputs import NotificationUpdatableProperties
from .inputs import NotificationUpdatableProperties

View File

@ -1,6 +1,7 @@
from uuid import UUID
from fastapi import APIRouter, Depends
from quivr_api.middlewares.auth import AuthBearer
from quivr_api.modules.prompt.entity.prompt import (
CreatePromptProperties,

View File

@ -1 +1,7 @@
from .prompt import Prompt, PromptStatusEnum, CreatePromptProperties, PromptUpdatableProperties, DeletePromptResponse
from .prompt import (
CreatePromptProperties,
DeletePromptResponse,
Prompt,
PromptStatusEnum,
PromptUpdatableProperties,
)

View File

@ -4,6 +4,7 @@ import requests
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
from msal import ConfidentialClientApplication
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput

View File

@ -3,6 +3,7 @@ import os
import requests
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput

View File

@ -50,4 +50,4 @@ successfullConnectionPage = """
</div>
</body>
</html>
"""
"""

View File

@ -3,6 +3,7 @@ from typing import Any, List, Literal, Union
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator
from quivr_api.modules.sync.entity.sync_models import NotionSyncFile

View File

@ -2,20 +2,20 @@ from datetime import datetime, timedelta
from typing import List, Sequence
from uuid import UUID
from quivr_api.logger import get_logger
from quivr_api.modules.dependencies import (BaseRepository, get_supabase_client)
from quivr_api.modules.notification.service.notification_service import \
NotificationService
from quivr_api.modules.sync.dto.inputs import (SyncsActiveInput,
SyncsActiveUpdateInput)
from quivr_api.modules.sync.entity.sync_models import (NotionSyncFile,
SyncsActive)
from quivr_api.modules.sync.repository.sync_interfaces import SyncInterface
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.logger import get_logger
from quivr_api.modules.dependencies import BaseRepository, get_supabase_client
from quivr_api.modules.notification.service.notification_service import (
NotificationService,
)
from quivr_api.modules.sync.dto.inputs import SyncsActiveInput, SyncsActiveUpdateInput
from quivr_api.modules.sync.entity.sync_models import NotionSyncFile, SyncsActive
from quivr_api.modules.sync.repository.sync_interfaces import SyncInterface
notification_service = NotificationService()
logger = get_logger(__name__)

View File

@ -1,4 +1,4 @@
from .email_sender import EmailSenderTool
from .image_generator import ImageGeneratorTool
from .web_search import WebSearchTool
from .url_reader import URLReaderTool
from .email_sender import EmailSenderTool
from .web_search import WebSearchTool

View File

@ -11,6 +11,7 @@ from langchain.pydantic_v1 import Field as FieldV1
from langchain_community.document_loaders import PlaywrightURLLoader
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from quivr_api.logger import get_logger
logger = get_logger(__name__)
@ -29,7 +30,6 @@ class URLReaderTool(BaseTool):
def _run(
self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> Dict:
loader = PlaywrightURLLoader(urls=[url], remove_selectors=["header", "footer"])
data = loader.load()

View File

@ -10,6 +10,7 @@ from langchain.pydantic_v1 import BaseModel as BaseModelV1
from langchain.pydantic_v1 import Field as FieldV1
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from quivr_api.logger import get_logger
logger = get_logger(__name__)

View File

@ -1,8 +1,8 @@
import os
from multiprocessing import get_logger
from quivr_api.modules.dependencies import get_supabase_client
from supabase.client import Client
import os
logger = get_logger()
@ -10,6 +10,7 @@ SIGNED_URL_EXPIRATION_PERIOD_IN_SECONDS = 3600
EXTERNAL_SUPABASE_URL = os.getenv("EXTERNAL_SUPABASE_URL", None)
SUPABASE_URL = os.getenv("SUPABASE_URL", None)
def generate_file_signed_url(path):
supabase_client: Client = get_supabase_client()
@ -25,7 +26,9 @@ def generate_file_signed_url(path):
logger.info("RESPONSE SIGNED URL", response)
# Replace in the response the supabase url by the external supabase url in the object signedURL
if EXTERNAL_SUPABASE_URL and SUPABASE_URL:
response["signedURL"] = response["signedURL"].replace(SUPABASE_URL, EXTERNAL_SUPABASE_URL)
response["signedURL"] = response["signedURL"].replace(
SUPABASE_URL, EXTERNAL_SUPABASE_URL
)
return response
except Exception as e:
logger.error(e)

View File

@ -1,8 +1,7 @@
from multiprocessing import get_logger
from supabase.client import Client
from quivr_api.modules.dependencies import get_supabase_client
from supabase.client import Client
logger = get_logger()

View File

@ -10,4 +10,3 @@ class UserUpdatableProperties(BaseModel):
onboarded: Optional[bool] = None
company_size: Optional[str] = None
usage_purpose: Optional[str] = None

View File

@ -1 +1 @@
from .user_service import UserService
from .user_service import UserService

View File

@ -3,12 +3,11 @@ from uuid import UUID
from pgvector.sqlalchemy import Vector as PGVector
from pydantic import BaseModel
from quivr_api.models.settings import settings
from sqlalchemy import Column
from sqlmodel import JSON, Column, Field, SQLModel, text
from sqlmodel import UUID as PGUUID
from quivr_api.models.settings import settings
class Vector(SQLModel, table=True):
__tablename__ = "vectors" # type: ignore

View File

@ -3,6 +3,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
from quivr_api.models.brains_subscription_invitations import BrainSubscription
@ -253,7 +254,7 @@ async def accept_invitation(
is_default_brain=False,
)
shared_brain = brain_service.get_brain_by_id(brain_id)
except Exception as e:
logger.error(f"Error adding user to brain: {e}")
raise HTTPException(status_code=400, detail=f"Error adding user to brain: {e}")

View File

@ -1,6 +1,7 @@
from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from quivr_api.logger import get_logger
logger = get_logger(__name__)

View File

@ -1,38 +1,46 @@
from unittest.mock import patch, MagicMock
from quivr_api.modules.dependencies import get_embedding_client
from unittest.mock import MagicMock, patch
from langchain_community.embeddings.ollama import OllamaEmbeddings
from langchain_openai import AzureOpenAIEmbeddings
from quivr_api.modules.dependencies import get_embedding_client
def test_ollama_embedding():
with patch("quivr_api.modules.dependencies.settings") as mock_settings:
mock_settings.ollama_api_base_url = "http://ollama.example.com"
mock_settings.azure_openai_embeddings_url = None
embedding_client = get_embedding_client()
assert isinstance(embedding_client, OllamaEmbeddings)
assert embedding_client.base_url == "http://ollama.example.com"
def test_azure_embedding():
with patch("quivr_api.modules.dependencies.settings") as mock_settings:
mock_settings.ollama_api_base_url = None
mock_settings.azure_openai_embeddings_url = "https://quivr-test.openai.azure.com/openai/deployments/embedding/embeddings?api-version=2023-05-15"
embedding_client = get_embedding_client()
assert isinstance(embedding_client, AzureOpenAIEmbeddings)
assert embedding_client.azure_endpoint == "https://quivr-test.openai.azure.com"
def test_openai_embedding():
with patch("quivr_api.modules.dependencies.settings") as mock_settings, \
patch("quivr_api.modules.dependencies.OpenAIEmbeddings") as mock_openai_embeddings:
with (
patch("quivr_api.modules.dependencies.settings") as mock_settings,
patch(
"quivr_api.modules.dependencies.OpenAIEmbeddings"
) as mock_openai_embeddings,
):
mock_settings.ollama_api_base_url = None
mock_settings.azure_openai_embeddings_url = None
# Create a mock instance for OpenAIEmbeddings
mock_openai_instance = MagicMock()
mock_openai_embeddings.return_value = mock_openai_instance
embedding_client = get_embedding_client()
assert embedding_client == mock_openai_instance

View File

@ -1,12 +1,11 @@
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.language_models import FakeListChatModel
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt
from quivr_core import Brain
from quivr_core.config import LLMEndpointConfig
from quivr_core.llm.llm_endpoint import LLMEndpoint
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt
if __name__ == "__main__":
brain = Brain.from_files(

View File

@ -3,6 +3,7 @@ from typing import Any, Generator, Tuple
from uuid import UUID, uuid4
from langchain_core.messages import AIMessage, HumanMessage
from quivr_core.models import ChatMessage
@ -54,7 +55,7 @@ class ChatHistory:
"""
# Reverse the chat_history, newest first
it = iter(self.get_chat_history(newest_first=True))
for ai_message, human_message in zip(it, it):
for ai_message, human_message in zip(it, it, strict=False):
assert isinstance(
human_message.msg, HumanMessage
), f"msg {human_message} is not HumanMessage"

View File

@ -55,7 +55,7 @@ class ChatLLM:
filtered_chat_history.append(ai_message)
total_tokens += message_tokens
total_pairs += 1
return filtered_chat_history
def build_chain(self):

View File

@ -14,19 +14,20 @@ logger = logging.getLogger("quivr_core")
class MegaparseProcessor(ProcessorBase):
'''
"""
Megaparse processor for PDF files.
It can be used to parse PDF files and split them into chunks.
It comes from the megaparse library.
## Installation
```bash
pip install megaparse
```
'''
"""
supported_extensions = [FileExtension.pdf]
def __init__(

View File

@ -37,6 +37,7 @@ def model_supports_function_calling(model_name: str):
]
return model_name in models_supporting_function_calls
def format_history_to_openai_mesages(
tuple_history: List[Tuple[str, str]], system_message: str, question: str
) -> List[BaseMessage]:
@ -125,7 +126,11 @@ def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGRe
)
if model_supports_function_calling(model_name):
if 'tool_calls' in raw_response["answer"] and raw_response["answer"].tool_calls and "citations" in raw_response["answer"].tool_calls[-1]["args"]:
if (
"tool_calls" in raw_response["answer"]
and raw_response["answer"].tool_calls
and "citations" in raw_response["answer"].tool_calls[-1]["args"]
):
citations = raw_response["answer"].tool_calls[-1]["args"]["citations"]
metadata.citations = citations
followup_questions = raw_response["answer"].tool_calls[-1]["args"][
@ -147,7 +152,7 @@ def combine_documents(
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
# for each docs, add an index in the metadata to be able to cite the sources
for doc, index in zip(docs, range(len(docs))):
for doc, index in zip(docs, range(len(docs)), strict=False):
doc.metadata["index"] = index
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)

View File

@ -5,7 +5,6 @@ from uuid import uuid4
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.vectorstores import InMemoryVectorStore
from quivr_core.chat import ChatHistory
from quivr_core.config import LLMEndpointConfig, RAGConfig
from quivr_core.llm import LLMEndpoint

View File

@ -2,7 +2,6 @@ from pathlib import Path
from uuid import uuid4
import pytest
from quivr_core.files.file import FileExtension, QuivrFile
from quivr_core.processor.implementations.default import MarkdownProcessor

View File

@ -2,7 +2,6 @@ from pathlib import Path
from uuid import uuid4
import pytest
from quivr_core.files.file import FileExtension, QuivrFile
from quivr_core.processor.implementations.default import DOCXProcessor

View File

@ -2,7 +2,6 @@ from pathlib import Path
from uuid import uuid4
import pytest
from quivr_core.files.file import FileExtension, QuivrFile
from quivr_core.processor.implementations.default import EpubProcessor

View File

@ -2,7 +2,6 @@ from pathlib import Path
from uuid import uuid4
import pytest
from quivr_core.files.file import FileExtension, QuivrFile
from quivr_core.processor.implementations.default import ODTProcessor

View File

@ -2,7 +2,6 @@ from pathlib import Path
from uuid import uuid4
import pytest
from quivr_core.files.file import FileExtension, QuivrFile
from quivr_core.processor.implementations.default import UnstructuredPDFProcessor

View File

@ -1,5 +1,4 @@
import pytest
from quivr_core.files.file import FileExtension
from quivr_core.processor.processor_base import ProcessorBase
@ -7,7 +6,6 @@ from quivr_core.processor.processor_base import ProcessorBase
@pytest.mark.base
def test___build_processor():
from langchain_community.document_loaders.base import BaseLoader
from quivr_core.processor.implementations.default import _build_processor
cls = _build_processor("TestCLS", BaseLoader, [FileExtension.txt])

View File

@ -1,6 +1,5 @@
import pytest
from langchain_core.documents import Document
from quivr_core.files.file import FileExtension
from quivr_core.processor.implementations.simple_txt_processor import (
SimpleTxtProcessor,

View File

@ -1,5 +1,4 @@
import pytest
from quivr_core.processor.implementations.tika_processor import TikaProcessor
# TODO: TIKA server should be set

View File

@ -3,7 +3,6 @@ from uuid import uuid4
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from quivr_core.chat import ChatHistory

View File

@ -1,5 +1,4 @@
import pytest
from quivr_core import ChatLLM

View File

@ -3,7 +3,6 @@ import os
import pytest
from langchain_core.language_models import FakeListChatModel
from pydantic.v1.error_wrappers import ValidationError
from quivr_core.config import LLMEndpointConfig
from quivr_core.llm import LLMEndpoint

View File

@ -3,7 +3,6 @@ from uuid import uuid4
import pytest
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from quivr_core.utils import (
get_prev_message_str,
model_supports_function_calling,

View File

@ -173,6 +173,7 @@ debugpy==1.8.5
decorator==5.1.1
# via ipython
defusedxml==0.7.1
# via fpdf2
# via langchain-anthropic
# via nbconvert
deprecated==1.2.14
@ -238,10 +239,11 @@ flatbuffers==24.3.25
flower==2.0.1
# via quivr-worker
fonttools==4.53.1
# via fpdf2
# via matplotlib
# via pdf2docx
fpdf==1.7.2
# via quivr-api
fpdf2==2.7.9
# via quivr-worker
frozenlist==1.4.1
# via aiohttp
# via aiosignal
@ -747,6 +749,7 @@ pgvector==0.3.2
pikepdf==9.1.1
# via unstructured
pillow==10.2.0
# via fpdf2
# via layoutparser
# via llama-index-core
# via matplotlib

View File

@ -150,6 +150,7 @@ debugpy==1.8.5
decorator==5.1.1
# via ipython
defusedxml==0.7.1
# via fpdf2
# via langchain-anthropic
# via nbconvert
deprecated==1.2.14
@ -200,10 +201,11 @@ flatbuffers==24.3.25
flower==2.0.1
# via quivr-worker
fonttools==4.53.1
# via fpdf2
# via matplotlib
# via pdf2docx
fpdf==1.7.2
# via quivr-api
fpdf2==2.7.9
# via quivr-worker
frozenlist==1.4.1
# via aiohttp
# via aiosignal
@ -650,6 +652,7 @@ pgvector==0.3.2
pikepdf==9.1.1
# via unstructured
pillow==10.2.0
# via fpdf2
# via layoutparser
# via llama-index-core
# via matplotlib

View File

@ -0,0 +1,79 @@
-- Migration: create the assistant tasks table with RLS and role grants.
create table "public"."tasks" (
    "id" bigint generated by default as identity not null,
    "pretty_id" text,
    -- Defaults to the authenticated caller's id.
    "user_id" uuid not null default auth.uid(),
    "status" text,
    -- Stored in UTC regardless of session timezone.
    "creation_time" timestamp with time zone default (now() AT TIME ZONE 'utc'::text),
    "answer_raw" jsonb,
    "answer_pretty" text
);
-- Row level security: rows are only visible through the policies below.
alter table "public"."tasks" enable row level security;
CREATE UNIQUE INDEX tasks_pkey ON public.tasks USING btree (id);
alter table "public"."tasks" add constraint "tasks_pkey" PRIMARY KEY using index "tasks_pkey";
-- Cascade so deleting a user removes their tasks; added NOT VALID then
-- validated separately to avoid a long lock on existing rows.
alter table "public"."tasks" add constraint "tasks_user_id_fkey" FOREIGN KEY (user_id) REFERENCES auth.users(id) ON UPDATE CASCADE ON DELETE CASCADE not valid;
alter table "public"."tasks" validate constraint "tasks_user_id_fkey";
-- Grants for the standard Supabase roles (anon, authenticated, service_role).
-- RLS policies below still gate which rows each role can actually touch.
grant delete on table "public"."tasks" to "anon";
grant insert on table "public"."tasks" to "anon";
grant references on table "public"."tasks" to "anon";
grant select on table "public"."tasks" to "anon";
grant trigger on table "public"."tasks" to "anon";
grant truncate on table "public"."tasks" to "anon";
grant update on table "public"."tasks" to "anon";
grant delete on table "public"."tasks" to "authenticated";
grant insert on table "public"."tasks" to "authenticated";
grant references on table "public"."tasks" to "authenticated";
grant select on table "public"."tasks" to "authenticated";
grant trigger on table "public"."tasks" to "authenticated";
grant truncate on table "public"."tasks" to "authenticated";
grant update on table "public"."tasks" to "authenticated";
grant delete on table "public"."tasks" to "service_role";
grant insert on table "public"."tasks" to "service_role";
grant references on table "public"."tasks" to "service_role";
grant select on table "public"."tasks" to "service_role";
grant trigger on table "public"."tasks" to "service_role";
grant truncate on table "public"."tasks" to "service_role";
grant update on table "public"."tasks" to "service_role";
-- Users may only operate on their own rows.
create policy "allow_user_all_tasks"
on "public"."tasks"
as permissive
for all
to public
using ((user_id = ( SELECT auth.uid() AS uid)));
-- Service role gets unrestricted access (no USING clause defaults to true).
create policy "tasks"
on "public"."tasks"
as permissive
for all
to service_role;

View File

@ -0,0 +1,11 @@
-- Migration: replace the raw/pretty answer pair with a single text answer,
-- and attach each task to an assistant with its settings payload.
alter table "public"."tasks" drop column "answer_pretty";
alter table "public"."tasks" drop column "answer_raw";
alter table "public"."tasks" add column "answer" text;
-- NOTE: adding a NOT NULL column without a default fails if rows exist;
-- assumes the table is empty at migration time.
alter table "public"."tasks" add column "assistant_id" bigint not null;
alter table "public"."tasks" add column "settings" jsonb;

View File

@ -13,10 +13,11 @@ dependencies = [
"playwright>=1.0.0",
"openai>=1.0.0",
"flower>=2.0.1",
"torch==2.4.0; platform_machine != 'x86_64'" ,
"torch==2.4.0; platform_machine != 'x86_64'",
"torch==2.4.0+cpu; platform_machine == 'x86_64'",
"torchvision==0.19.0; platform_machine != 'x86_64'",
"torchvision==0.19.0+cpu; platform_machine == 'x86_64'",
"fpdf2>=2.7.9",
]
readme = "README.md"
requires-python = ">= 3.11"

View File

@ -0,0 +1,40 @@
import os
from quivr_api.modules.assistant.services.tasks_service import TasksService
from quivr_api.modules.upload.service.upload_file import (
upload_file_storage,
)
from quivr_worker.utils.pdf_generator.pdf_generator import PDFGenerator, PDFModel
async def process_assistant(
    assistant_id: str,
    notification_uuid: str,
    task_id: int,
    tasks_service: TasksService,
    user_id: str,
):
    """Mock assistant processing pipeline.

    Marks the task in progress, renders a placeholder PDF answer, uploads it
    to storage under ``{assistant_id}/{notification_uuid}/output.pdf``, then
    marks the task completed with a canned answer.

    Args:
        assistant_id: Identifier of the assistant handling the task; used to
            namespace the output path.
        notification_uuid: Notification identifier; second path segment of the
            output location.
        task_id: Primary key of the task row to read and update.
        tasks_service: Service used to fetch and update the task.
        user_id: Owner of the task, passed to the lookup for authorization.
    """
    task = await tasks_service.get_task_by_id(task_id, user_id)  # type: ignore
    await tasks_service.update_task(task_id, {"status": "in_progress"})

    # Placeholder result until the real assistant pipeline is wired in.
    task_result = {"status": "completed", "answer": "#### Assistant answer"}

    output_dir = f"{assistant_id}/{notification_uuid}"
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/output.pdf"

    generated_pdf = PDFGenerator(PDFModel(title="Test", content="Test"))
    generated_pdf.print_pdf()
    generated_pdf.output(output_path)

    try:
        with open(output_path, "rb") as file:
            await upload_file_storage(file, output_path)
    finally:
        # Always remove the local artifact, even if the upload fails,
        # so failed retries don't accumulate files on the worker.
        os.remove(output_path)

    await tasks_service.update_task(task_id, task_result)

View File

@ -8,6 +8,8 @@ from dotenv import load_dotenv
from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
from quivr_api.models.settings import settings
from quivr_api.modules.assistant.repository.tasks import TasksRepository
from quivr_api.modules.assistant.services.tasks_service import TasksService
from quivr_api.modules.brain.integrations.Notion.Notion_connector import NotionConnector
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
from quivr_api.modules.brain.service.brain_service import BrainService
@ -29,6 +31,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import Session, text
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_worker.assistants.assistants import process_assistant
from quivr_worker.check_premium import check_is_premium
from quivr_worker.process.process_s3_file import process_uploaded_file
from quivr_worker.process.process_url import process_url_func
@ -39,7 +42,7 @@ from quivr_worker.syncs.process_active_syncs import (
process_sync,
)
from quivr_worker.syncs.store_notion import fetch_and_store_notion_files_async
from quivr_worker.utils import _patch_json
from quivr_worker.utils.utils import _patch_json
load_dotenv()
@ -91,6 +94,63 @@ def init_worker(**kwargs):
)
@celery.task(
    retries=3,
    default_retry_delay=1,
    name="process_assistant_task",
    autoretry_for=(Exception,),
)
def process_assistant_task(
    assistant_id: str,
    notification_uuid: str,
    task_id: int,
    user_id: str,
):
    """Celery entry point for assistant processing.

    Bridges the synchronous Celery worker into the async implementation by
    running ``aprocess_assistant_task`` on the worker's event loop. Retried
    automatically on any exception per the task decorator.
    """
    logger.info(
        f"process_assistant_task started for assistant_id={assistant_id}, notification_uuid={notification_uuid}, task_id={task_id}"
    )
    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        aprocess_assistant_task(
            assistant_id,
            notification_uuid,
            task_id,
            user_id,
        )
    )
async def aprocess_assistant_task(
    assistant_id: str,
    notification_uuid: str,
    task_id: int,
    user_id: str,
):
    """Async implementation behind ``process_assistant_task``.

    Opens a database session, builds the tasks repository/service pair, and
    delegates to ``process_assistant``. Rolls back the session on any failure
    and re-raises so the Celery retry policy can kick in.
    """
    async with AsyncSession(async_engine) as async_session:
        try:
            # Guard against connections left idle inside an open transaction.
            await async_session.execute(
                text("SET SESSION idle_in_transaction_session_timeout = '5min';")
            )
            tasks_repository = TasksRepository(async_session)
            tasks_service = TasksService(tasks_repository)
            await process_assistant(
                assistant_id,
                notification_uuid,
                task_id,
                tasks_service,
                user_id,
            )
        except Exception:
            await async_session.rollback()
            # Bare raise preserves the original traceback.
            raise
        # `async with AsyncSession(...)` closes the session on exit; no
        # explicit close is needed.
@celery.task(
retries=3,
default_retry_delay=1,
@ -111,10 +171,6 @@ def process_file_task(
if async_engine is None:
init_worker()
logger.info(
f"Task process_file started for file_name={file_name}, knowledge_id={knowledge_id}, brain_id={brain_id}, notification_id={notification_id}"
)
loop = asyncio.get_event_loop()
loop.run_until_complete(
aprocess_file_task(

View File

@ -9,7 +9,7 @@ from uuid import UUID
from quivr_api.logger import get_logger
from quivr_core.files.file import FileExtension, QuivrFile
from quivr_worker.utils import get_tmp_name
from quivr_worker.utils.utils import get_tmp_name
logger = get_logger("celery_worker")

Some files were not shown because too many files have changed in this diff Show More