Mirror of https://github.com/QuivrHQ/quivr.git (synced 2024-12-15 01:21:48 +03:00)
feat(quivr-core): beginning (#3388)
# Description

Please include a summary of the changes and the related issue. Please also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
commit 7acb52a963 (parent 4e96ad86cb)
.env.example (106 lines deleted)
@@ -1,106 +0,0 @@
#### QUIVR Configuration
# This file is used to configure the Quivr stack. It is used by the `docker-compose.yml` file to configure the stack.

# API KEYS
# OPENAI. Update this to use your API key. To skip OpenAI integration use a fake key, for example: tk-aabbccddAABBCCDDEeFfGgHhIiJKLmnopjklMNOPqQqQqQqQ
OPENAI_API_KEY=your-openai-api-key
# ANTHROPIC_API_KEY=your-anthropic-api-key
# MISTRAL_API_KEY=your-mistral-api-key
# GROQ_API_KEY=your-groq-api-key

COHERE_API_KEY=your-cohere-api-key
# JINA_API_KEY=your-jina-api-key

# UNSTRUCTURED_API_KEY=your-unstructured-api-key
# UNSTRUCTURED_API_URL=https://api.unstructured.io/general/v0/general

# LLAMA_PARSE_API_KEY=your-llamaparse-api-key

# Configuration files path
BRAIN_CONFIG_PATH=config/retrieval_config_workflow.yaml
CHAT_LLM_CONFIG_PATH=config/chat_llm_config.yaml

# LangSmith
# LANGCHAIN_TRACING_V2=true
# LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
# LANGCHAIN_API_KEY=your-langchain-api-key
# LANGCHAIN_PROJECT=your-langchain-project-name

# LOCAL
# OLLAMA_API_BASE_URL=http://host.docker.internal:11434 # Uncomment to activate ollama. This is the local url for the ollama api

########
# FRONTEND
########

NEXT_PUBLIC_ENV=local
NEXT_PUBLIC_BACKEND_URL=http://localhost:5050
NEXT_PUBLIC_SUPABASE_URL=http://localhost:54321
NEXT_PUBLIC_SUPABASE_ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZS1kZW1vIiwicm9sZSI6ImFub24iLCJleHAiOjE5ODM4MTI5OTZ9.CRXP1A7WOeoJeXxjNni43kdQwgnWNReilDMblYTn_I0
NEXT_PUBLIC_CMS_URL=https://cms.quivr.app
NEXT_PUBLIC_FRONTEND_URL=http://localhost:*
NEXT_PUBLIC_AUTH_MODES=password
NEXT_PUBLIC_SHOW_TOKENS=false
#NEXT_PUBLIC_PROJECT_NAME=<project-name>


########
# BACKEND
########

LOG_LEVEL=INFO
SUPABASE_URL=http://host.docker.internal:54321
EXTERNAL_SUPABASE_URL=http://localhost:54321
SUPABASE_SERVICE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZS1kZW1vIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImV4cCI6MTk4MzgxMjk5Nn0.EGIM96RAZx35lJzdJsyH-qQwv8Hdp7fsn3W0YpN81IU
PG_DATABASE_URL=postgresql://postgres:postgres@host.docker.internal:54322/postgres
PG_DATABASE_ASYNC_URL=postgresql+asyncpg://postgres:postgres@host.docker.internal:54322/postgres
JWT_SECRET_KEY=super-secret-jwt-token-with-at-least-32-characters-long
AUTHENTICATE=true
TELEMETRY_ENABLED=true
CELERY_BROKER_URL=redis://redis:6379/0
CELEBRY_BROKER_QUEUE_NAME=quivr-preview.fifo
QUIVR_DOMAIN=http://localhost:3000/
BACKEND_URL=http://localhost:5050
EMBEDDING_DIM=1536
DEACTIVATE_STRIPE=true


# PARSEABLE LOGGING
USE_PARSEABLE=False
PARSEABLE_STREAM_NAME=quivr-api
PARSEABLE_URL=<change-me>
PARSEABLE_AUTH=<change-me>

#RESEND
RESEND_API_KEY=<change-me>
RESEND_EMAIL_ADDRESS=onboarding@resend.dev
RESEND_CONTACT_SALES_FROM=contact_sales@resend.dev
RESEND_CONTACT_SALES_TO=<change-me>

# SMTP
QUIVR_SMTP_SERVER=smtp.example.com
QUIVR_SMTP_PORT=587
QUIVR_SMTP_USERNAME=username
QUIVR_SMTP_PASSWORD=password

CRAWL_DEPTH=1

PREMIUM_MAX_BRAIN_NUMBER=30
PREMIUM_MAX_BRAIN_SIZE=10000000
PREMIUM_DAILY_CHAT_CREDIT=100

# BRAVE SEARCH API KEY
BRAVE_SEARCH_API_KEY=CHANGE_ME


# GOOGLE DRIVE
GOOGLE_CLIENT_ID=your-client-id
GOOGLE_CLIENT_SECRET=your-client-secret
GOOGLE_PROJECT_ID=your-project-id
GOOGLE_AUTH_URI=https://accounts.google.com/o/oauth2/auth
GOOGLE_TOKEN_URI=https://oauth2.googleapis.com/token
GOOGLE_AUTH_PROVIDER_CERT_URL=https://www.googleapis.com/oauth2/v1/certs
GOOGLE_REDIRECT_URI=http://localhost

# SHAREPOINT
SHAREPOINT_CLIENT_ID=your-client-id
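Most of these backend variables are read by the Python services through pydantic-settings (listed later in this commit as a quivr-api dependency). A minimal sketch of how such values could be loaded is shown below; the class and field names are illustrative only and are not the actual quivr-api settings models.

```python
# Illustrative sketch: loading a few of the variables above with pydantic-settings.
# The settings class below is hypothetical, not part of quivr-api.
from pydantic_settings import BaseSettings, SettingsConfigDict


class ExampleBackendSettings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    openai_api_key: str = "your-openai-api-key"
    supabase_url: str = "http://host.docker.internal:54321"
    pg_database_url: str = "postgresql://postgres:postgres@host.docker.internal:54322/postgres"
    celery_broker_url: str = "redis://redis:6379/0"
    embedding_dim: int = 1536
    authenticate: bool = True


settings = ExampleBackendSettings()  # values from .env override the defaults above
print(settings.embedding_dim)
```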
.flake8 (4 lines deleted)
@@ -1,4 +0,0 @@
[flake8]
; Minimal configuration for Flake8 to work with Black.
max-line-length = 100
ignore = E101,E111,E112,E221,E222,E501,E711,E712,W503,W504,F401
.github/workflows/aws-strapi.yml (vendored, 103 lines deleted)
@@ -1,103 +0,0 @@
name: Deploy Strapi

on:
  push:
    branches: ["main"]
    paths:
      - "cms/**"

env:
  AWS_REGION: eu-west-3
  ECR_REPOSITORY: quivr-strapi
  ECR_REGISTRY: 253053805092.dkr.ecr.eu-west-3.amazonaws.com
  ECS_CLUSTER: quivr

jobs:
  build_and_push:
    name: Build and Push Docker Image
    runs-on: ubuntu-latest
    environment: production

    steps:
      - name: Checkout
        uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@5fd3084fc36e372ff1fff382a39b10d03659f355 # v2
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}

      - name: Login to Amazon ECR
        id: login-ecr
        uses: aws-actions/amazon-ecr-login@2fc7aceee09e9e4a7105c0d060c656fad0b4f63d # v1

      - name: Login to GitHub Container Registry
        uses: docker/login-action@465a07811f14bebb1938fbed4728c6a1ff8901fc # v2
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@885d1462b80bc1c1c7f0b00334ad271f09369c55 # v2

      - name: Create Docker Cacha Storage Backend
        run: |
          docker buildx create --use --driver=docker-container
      - name: See the file in the runner
        run: |
          ls -la
      - name: Build, tag, and push image to Amazon ECR
        id: build-image
        uses: docker/build-push-action@0a97817b6ade9f46837855d676c4cca3a2471fc9 # v4
        env:
          ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
          IMAGE_TAG: ${{ github.sha }}
        with:
          context: ./cms/quivr/
          push: true
          tags: ${{ env.ECR_REGISTRY }}/${{ env.ECR_REPOSITORY }}:${{ env.IMAGE_TAG }}, ${{ env.ECR_REGISTRY }}/${{ env.ECR_REPOSITORY }}:latest, ghcr.io/quivrhq/quivr:latest
          cache-from: type=gha
          cache-to: type=gha,mode=max

  deploy:
    needs: build_and_push
    runs-on: ubuntu-latest
    environment: production
    strategy:
      fail-fast: false
      matrix:
        include:
          - name: "strapi"
            service: "strapi"
            task_definition: ".aws/task_definition_strapi.json"
            container: "strapi"

    steps:
      - name: Checkout
        uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@5fd3084fc36e372ff1fff382a39b10d03659f355 # v2
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}

      - name: Fill in the new image ID in the Amazon ECS task definition for ${{ matrix.name }}
        id: task-def
        uses: aws-actions/amazon-ecs-render-task-definition@4225e0b507142a2e432b018bc3ccb728559b437a # v1
        with:
          task-definition: ${{ matrix.task_definition }}
          container-name: ${{ matrix.container }}
          image: ${{env.ECR_REGISTRY}}/${{ env.ECR_REPOSITORY }}:${{ github.sha }}

      - name: Deploy Amazon ECS task definition for ${{ matrix.name }}
        uses: aws-actions/amazon-ecs-deploy-task-definition@df9643053eda01f169e64a0e60233aacca83799a # v1
        with:
          task-definition: ${{ steps.task-def.outputs.task-definition }}
          service: ${{ matrix.service }}
          cluster: ${{ env.ECS_CLUSTER }}
          wait-for-service-stability: true
.github/workflows/backend-core-tests.yml (vendored, 8 lines changed: the core package moves from backend/core to core)
@@ -3,10 +3,10 @@ name: Run Tests with Tika Server
 on:
   push:
     paths:
-      - "backend/core/**"
+      - "core/**"
   pull_request:
     paths:
-      - "backend/core/**"
+      - "core/**"
   workflow_dispatch:

 jobs:
@@ -30,7 +30,7 @@ jobs:
           working-directory: backend
       - name: 🔄 Sync dependencies
         run: |
-          cd backend
+          cd core
           UV_INDEX_STRATEGY=unsafe-first-match rye sync --no-lock

       - name: Run tests
@@ -40,7 +40,7 @@ jobs:
         run: |
           sudo apt-get update
           sudo apt-get install -y libmagic-dev poppler-utils libreoffice tesseract-ocr pandoc
-          cd backend
+          cd core
           rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()"
           rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
           rye test -p quivr-core
.github/workflows/backend-tests.yml (vendored, 70 lines deleted)
@@ -1,70 +0,0 @@
name: Run Tests API and Worker

on:
  pull_request:
    paths:
      - "backend/**"
  workflow_dispatch:

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        project: [quivr-api, quivr-worker]
    steps:
      - name: 👀 Checkout code
        uses: actions/checkout@v2

      - name: 🔨 Install the latest version of rye
        uses: eifinger/setup-rye@v4
        with:
          enable-cache: true
          working-directory: backend

      - name: 🔄 Sync dependencies
        run: |
          cd backend
          UV_INDEX_STRATEGY=unsafe-first-match rye sync --no-lock

      - name: 🚤 Install Supabase CLI
        run: |
          ARCHITECTURE=$(uname -m)
          if [ "$ARCHITECTURE" = "x86_64" ]; then
            wget https://github.com/supabase/cli/releases/download/v1.163.6/supabase_1.163.6_linux_amd64.deb
            sudo dpkg -i supabase_1.163.6_linux_amd64.deb
          elif [ "$ARCHITECTURE" = "aarch64" ]; then
            wget https://github.com/supabase/cli/releases/download/v1.163.6/supabase_1.163.6_linux_arm64.deb
            sudo dpkg -i supabase_1.163.6_linux_arm64.deb
          fi

      - name: 😭 Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y libmagic-dev poppler-utils libreoffice tesseract-ocr pandoc

      - name: Install dependencies and run tests
        env:
          OPENAI_API_KEY: this-is-a-fake-openai-api-key
          SUPABASE_URL: http://localhost:54321
          SUPABASE_SERVICE_KEY: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZS1kZW1vIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImV4cCI6MTk4MzgxMjk5Nn0.EGIM96RAZx35lJzdJsyH-qQwv8Hdp7fsn3W0YpN81IU
          PG_DATABASE_URL: postgresql://postgres:postgres@localhost:54322/postgres
          PG_DATABASE_ASYNC_URL: postgresql+asyncpg://postgres:postgres@localhost:54322/postgres
          ANTHROPIC_API_KEY: null
          JWT_SECRET_KEY: super-secret-jwt-token-with-at-least-32-characters-long
          AUTHENTICATE: true
          TELEMETRY_ENABLED: true
          CELERY_BROKER_URL: redis://redis:6379/0
          CELEBRY_BROKER_QUEUE_NAME: quivr-preview.fifo
          QUIVR_DOMAIN: http://localhost:3000/
          BACKEND_URL: http://localhost:5050
          EMBEDDING_DIM: 1536
        run: |
          cd backend
          sed -i 's/enabled = true/enabled = false/g' supabase/config.toml
          sed -i '/\[db\]/,/\[.*\]/s/enabled = false/enabled = true/' supabase/config.toml
          sed -i '/\[storage\]/,/\[.*\]/s/enabled = false/enabled = true/' supabase/config.toml
          supabase start
          rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()"
          rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
          rye test -p ${{ matrix.project }}
.github/workflows/porter_stack_cdp-front.yml (vendored, 29 lines deleted)
@@ -1,29 +0,0 @@
"on":
  push:
    branches:
      - main
name: Deploy to cdp-front
jobs:
  porter-deploy:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set Github tag
        id: vars
        run: echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
      - name: Setup porter
        uses: porter-dev/setup-porter@v0.1.0
      - name: Deploy stack
        timeout-minutes: 30
        run: exec porter apply
        env:
          PORTER_CLUSTER: "3877"
          PORTER_DEPLOYMENT_TARGET_ID: cd21d246-86df-49e0-ba0a-78e2802572e7
          PORTER_HOST: https://dashboard.porter.run
          PORTER_PR_NUMBER: ${{ github.event.number }}
          PORTER_PROJECT: "10983"
          PORTER_REPO_NAME: ${{ github.event.repository.name }}
          PORTER_STACK_NAME: cdp-front
          PORTER_TAG: ${{ steps.vars.outputs.sha_short }}
          PORTER_TOKEN: ${{ secrets.PORTER_STACK_10983_3877 }}
.github/workflows/porter_stack_cdp.yml (vendored, 29 lines deleted)
@@ -1,29 +0,0 @@
"on":
  push:
    branches:
      - main
name: Deploy to cdp
jobs:
  porter-deploy:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set Github tag
        id: vars
        run: echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
      - name: Setup porter
        uses: porter-dev/setup-porter@v0.1.0
      - name: Deploy stack
        timeout-minutes: 30
        run: exec porter apply
        env:
          PORTER_CLUSTER: "3877"
          PORTER_DEPLOYMENT_TARGET_ID: cd21d246-86df-49e0-ba0a-78e2802572e7
          PORTER_HOST: https://dashboard.porter.run
          PORTER_PR_NUMBER: ${{ github.event.number }}
          PORTER_PROJECT: "10983"
          PORTER_REPO_NAME: ${{ github.event.repository.name }}
          PORTER_STACK_NAME: cdp
          PORTER_TAG: ${{ steps.vars.outputs.sha_short }}
          PORTER_TOKEN: ${{ secrets.PORTER_STACK_10983_3877 }}
@@ -1,28 +0,0 @@
"on":
  push:
    branches:
      - main
name: Deploy to preview-frontend
jobs:
  porter-deploy:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set Github tag
        id: vars
        run: echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
      - name: Setup porter
        uses: porter-dev/setup-porter@v0.1.0
      - name: Deploy stack
        timeout-minutes: 30
        run: exec porter apply
        env:
          PORTER_CLUSTER: "3877"
          PORTER_DEPLOYMENT_TARGET_ID: cd21d246-86df-49e0-ba0a-78e2802572e7
          PORTER_HOST: https://dashboard.getporter.dev
          PORTER_PR_NUMBER: ${{ github.event.number }}
          PORTER_PROJECT: "10983"
          PORTER_STACK_NAME: preview-frontend
          PORTER_TAG: ${{ steps.vars.outputs.sha_short }}
          PORTER_TOKEN: ${{ secrets.PORTER_STACK_10983_3877 }}
.github/workflows/porter_stack_preview.yml (vendored, 28 lines deleted)
@@ -1,28 +0,0 @@
"on":
  push:
    branches:
      - main
name: Deploy to preview
jobs:
  porter-deploy:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set Github tag
        id: vars
        run: echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
      - name: Setup porter
        uses: porter-dev/setup-porter@v0.1.0
      - name: Deploy stack
        timeout-minutes: 30
        run: exec porter apply
        env:
          PORTER_CLUSTER: "3877"
          PORTER_DEPLOYMENT_TARGET_ID: cd21d246-86df-49e0-ba0a-78e2802572e7
          PORTER_HOST: https://dashboard.getporter.dev
          PORTER_PR_NUMBER: ${{ github.event.number }}
          PORTER_PROJECT: "10983"
          PORTER_STACK_NAME: preview
          PORTER_TAG: ${{ steps.vars.outputs.sha_short }}
          PORTER_TOKEN: ${{ secrets.PORTER_STACK_10983_3877 }}
.github/workflows/porter_stack_production.yml (vendored, 28 lines deleted)
@@ -1,28 +0,0 @@
on:
  push:
    tags:
      - "v*"
name: Deploy to production
jobs:
  porter-deploy:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set Github tag
        id: vars
        run: echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
      - name: Setup porter
        uses: porter-dev/setup-porter@v0.1.0
      - name: Deploy stack
        timeout-minutes: 30
        run: exec porter apply
        env:
          PORTER_CLUSTER: "3877"
          PORTER_DEPLOYMENT_TARGET_ID: cd21d246-86df-49e0-ba0a-78e2802572e7
          PORTER_HOST: https://dashboard.getporter.dev
          PORTER_PR_NUMBER: ${{ github.event.number }}
          PORTER_PROJECT: "10983"
          PORTER_STACK_NAME: production
          PORTER_TAG: ${{ steps.vars.outputs.sha_short }}
          PORTER_TOKEN: ${{ secrets.PORTER_STACK_10983_3877 }}
@@ -1,29 +0,0 @@
on:
  push:
    tags:
      - 'v*'
name: Deploy to raise-frontend
jobs:
  porter-deploy:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set Github tag
        id: vars
        run: echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
      - name: Setup porter
        uses: porter-dev/setup-porter@v0.1.0
      - name: Deploy stack
        timeout-minutes: 30
        run: exec porter apply
        env:
          PORTER_CLUSTER: "3877"
          PORTER_DEPLOYMENT_TARGET_ID: cd21d246-86df-49e0-ba0a-78e2802572e7
          PORTER_HOST: https://dashboard.getporter.dev
          PORTER_PR_NUMBER: ${{ github.event.number }}
          PORTER_PROJECT: "10983"
          PORTER_REPO_NAME: ${{ github.event.repository.name }}
          PORTER_STACK_NAME: raise-frontend
          PORTER_TAG: ${{ steps.vars.outputs.sha_short }}
          PORTER_TOKEN: ${{ secrets.PORTER_STACK_10983_3877 }}
.github/workflows/porter_stack_raise.yml (vendored, 29 lines deleted)
@@ -1,29 +0,0 @@
on:
  push:
    tags:
      - 'v*'
name: Deploy to raise
jobs:
  porter-deploy:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set Github tag
        id: vars
        run: echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
      - name: Setup porter
        uses: porter-dev/setup-porter@v0.1.0
      - name: Deploy stack
        timeout-minutes: 30
        run: exec porter apply
        env:
          PORTER_CLUSTER: "3877"
          PORTER_DEPLOYMENT_TARGET_ID: cd21d246-86df-49e0-ba0a-78e2802572e7
          PORTER_HOST: https://dashboard.getporter.dev
          PORTER_PR_NUMBER: ${{ github.event.number }}
          PORTER_PROJECT: "10983"
          PORTER_REPO_NAME: ${{ github.event.repository.name }}
          PORTER_STACK_NAME: raise
          PORTER_TAG: ${{ steps.vars.outputs.sha_short }}
          PORTER_TOKEN: ${{ secrets.PORTER_STACK_10983_3877 }}
.github/workflows/prebuild-images.yml (vendored, 70 lines deleted)
@@ -1,70 +0,0 @@
name: Prebuild & Deploy Docker Images

on:
  push:
    branches: ["main"]
    # paths:
    #   - "backend/**"

env:
  AWS_REGION: eu-west-3
  ECR_REPOSITORY: backend
  ECR_REGISTRY: public.ecr.aws/c2l8c5w6

jobs:
  build_and_push:
    name: Build and Push Docker Image
    runs-on: ubuntu-latest
    environment: production

    steps:
      - name: Checkout
        uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@5fd3084fc36e372ff1fff382a39b10d03659f355 # v2
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}

      - name: Login to Amazon ECR
        id: login-ecr
        uses: aws-actions/amazon-ecr-login@2fc7aceee09e9e4a7105c0d060c656fad0b4f63d # v1

      - name: Login to GitHub Container Registry
        uses: docker/login-action@465a07811f14bebb1938fbed4728c6a1ff8901fc # v2
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: Login to Docker Hub
        uses: docker/login-action@343f7c4344506bcbf9b4de18042ae17996df046d # v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - uses: docker/setup-qemu-action@68827325e0b33c7199eb31dd4e31fbe9023e06e3 # v3
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@f95db51fddba0c2d1ec667646a06c2ce06100226 # v3

      - name: Create Docker Cacha Storage Backend
        run: |
          docker buildx create --use --driver=docker-container
      - name: See the file in the runner
        run: |
          ls -la
      - name: Build, tag, and push image to Amazon ECR
        id: build-image
        uses: docker/build-push-action@4a13e500e55cf31b7a5d59a38ab2040ab0f42f56 # v5
        env:
          ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
          IMAGE_TAG: ${{ github.sha }}
        with:
          context: ./backend/
          push: true
          platforms: linux/amd64,linux/arm64
          # tags: ${{ env.ECR_REGISTRY }}/${{ env.ECR_REPOSITORY }}:${{ env.IMAGE_TAG }}, ${{ env.ECR_REGISTRY }}/${{ env.ECR_REPOSITORY }}:latest, ghcr.io/quivrhq/quivr:latest, stangirard/quivr-backend-prebuilt:latest, stangirard/quivr-backend-prebuilt:${{ env.IMAGE_TAG }}
          tags: ghcr.io/quivrhq/quivr:latest, stangirard/quivr-backend-prebuilt:latest, stangirard/quivr-backend-prebuilt:${{ env.IMAGE_TAG }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
.github/workflows/release-please-core.yml (vendored, 8 lines changed: release-please now points at core instead of backend/core)
@@ -14,7 +14,7 @@ jobs:
   release-please:
     runs-on: ubuntu-latest
     outputs:
-      release_created: ${{ steps.release.outputs['backend/core--release_created'] }}
+      release_created: ${{ steps.release.outputs['core--release_created'] }}
     steps:
       - name: Checkout repository
        uses: actions/checkout@v3
@@ -30,7 +30,7 @@ jobs:
        id: release
        uses: google-github-actions/release-please-action@v4
        with:
-          path: backend/core
+          path: core
          token: ${{ secrets.RELEASE_PLEASE_TOKEN }}


@@ -40,14 +40,14 @@ jobs:
    runs-on: ubuntu-latest
    defaults:
      run:
-        working-directory: backend/
+        working-directory:
    steps:
      - uses: actions/checkout@v4
      - name: Install Rye
        uses: eifinger/setup-rye@v2
        with:
          enable-cache: true
-          working-directory: backend/
+          working-directory: core/
      - name: Rye Sync
        run: cd core/ && UV_INDEX_STRATEGY=unsafe-first-match rye sync --no-lock
      - name: Rye Build
.github/workflows/release-please.yml (vendored, 22 lines deleted)
@@ -1,22 +0,0 @@
on:
  push:
    branches:
      - main

permissions:
  contents: write
  pull-requests: write

name: release-please

jobs:
  release-please:
    runs-on: ubuntu-latest
    steps:
      - uses: google-github-actions/release-please-action@db8f2c60ee802b3748b512940dde88eabd7b7e01 # v3
        with:
          release-type: node
          changelog-notes-type: github
          package-name: release-please-action
          bump-patch-for-minor-pre-major: true
          token: ${{ secrets.RELEASE_PLEASE_TOKEN }}
.github/workflows/test-build-image.yml (vendored, 46 lines deleted)
@@ -1,46 +0,0 @@
name: Test Build Image

on:
  pull_request:
    branches: ["main"]
    paths:
      - 'backend/**'

jobs:
  build:
    name: Build Docker Image
    runs-on: ubuntu-latest
    strategy:
      matrix:
        platform: [linux/amd64, linux/arm64]

    steps:
      - name: Checkout
        uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4

      - name: Login to GitHub Container Registry
        uses: docker/login-action@465a07811f14bebb1938fbed4728c6a1ff8901fc # v2
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - uses: docker/setup-qemu-action@68827325e0b33c7199eb31dd4e31fbe9023e06e3 # v3
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@f95db51fddba0c2d1ec667646a06c2ce06100226 # v3

      - name: Create Docker Cacha Storage Backend
        run: |
          docker buildx create --use --driver=docker-container

      - name: Build image
        id: build-image
        uses: docker/build-push-action@4a13e500e55cf31b7a5d59a38ab2040ab0f42f56 # v5
        env:
          ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
          IMAGE_TAG: ${{ github.sha }}
        with:
          context: ./backend/
          platforms: ${{ matrix.platform }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
.github/workflows/vercel-preview.yml (vendored, 23 lines deleted)
@@ -1,23 +0,0 @@
name: Preview Deployment
env:
  VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
  VERCEL_PROJECT_ID: ${{ secrets.VERCEL_PROJECT_ID }}
on:
  push:
    branches: [ "main" ]
jobs:
  Deploy-Preview:
    environment: production
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4
        with:
          submodules: 'true'
      - name: Install Vercel CLI
        run: npm install --global vercel@latest
      - name: Pull Vercel Environment Information
        run: vercel pull --yes --environment=preview --token=${{ secrets.VERCEL_TOKEN }}
      - name: Build Project Artifacts
        run: vercel build --token=${{ secrets.VERCEL_TOKEN }}
      - name: Deploy Project Artifacts to Vercel
        run: vercel deploy --prebuilt --token=${{ secrets.VERCEL_TOKEN }}
.github/workflows/vercel.yml (vendored, 26 lines deleted)
@@ -1,26 +0,0 @@
name: Production Tag Deployment
env:
  VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
  VERCEL_PROJECT_ID: ${{ secrets.VERCEL_PROJECT_ID }}
on:
  push:
    # Pattern matched against refs/tags
    tags:
      - '*' # Push events to every tag not containing /
jobs:

  Deploy-Production:
    environment: production
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4
        with:
          submodules: 'true'
      - name: Install Vercel CLI
        run: npm install --global vercel@latest
      - name: Pull Vercel Environment Information
        run: vercel pull --yes --environment=production --token=${{ secrets.VERCEL_TOKEN }}
      - name: Build Project Artifacts
        run: vercel build --prod --token=${{ secrets.VERCEL_TOKEN }}
      - name: Deploy Project Artifacts to Vercel
        run: vercel deploy --prebuilt --prod --token=${{ secrets.VERCEL_TOKEN }}
@@ -1,4 +1,3 @@
 {
-  "backend/core": "0.0.18",
-  ".": "0.0.322"
+  "core": "0.0.18"
 }
.run_tests.sh (105 lines deleted)
@@ -1,105 +0,0 @@
#!/bin/bash
set -e

## TESTS SUITES
test_suites=(
  "Backend Core:cd backend/core && tox -p auto"
  "Worker:cd backend && pytest worker"
  "API:cd backend && pytest api"
)

# Check if gum is installed
if ! command -v gum &>/dev/null; then
  echo "gum is not installed. Please install it with 'brew install gum'."
  exit 1
fi

root_dir=$(pwd)

# Function to check if Tika server is running
check_tika_server() {
  if nc -z localhost 9998 >/dev/null 2>&1; then
    return 0
  else
    gum style --foreground 196 "Error: Tika server is not running on port 9998."
    gum style --foreground 226 "Please start the Tika server before running the tests."
    gum style --foreground 226 "Run 'docker run -d -p 9998:9998 apache/tika' to start the Tika server."
    return 1
  fi
}

# select test suites to run, either all or one of the following
get_test_suites_to_run() {
  gum style --bold "Select test suites to run:"
  options=("All" "${test_suites[@]%%:*}")
  selected=$(gum choose "${options[@]}")
  if [[ "$selected" == "All" ]]; then
    gum style --bold "Running all test suites"
  else
    # Find the matching test suite
    for suite in "${test_suites[@]}"; do
      if [[ "${suite%%:*}" == "$selected" ]]; then
        test_suites=("$suite")
        break
      fi
    done
  fi
}

# Function to run a single test suite
run_test_suite() {
  local suite_name=$1
  local command=$2
  local exit_code

  gum style --border normal --border-foreground 99 --padding "1 2" --bold "$suite_name Tests"
  eval "$command"
  exit_code=$?
  cd "$root_dir"

  if [ $exit_code -eq 0 ]; then
    gum style --foreground 46 "$suite_name Tests: PASSED"
  else
    gum style --foreground 196 "$suite_name Tests: FAILED"
  fi

  return $exit_code
}

run_tests() {
  get_test_suites_to_run
  # gum spin --spinner dot --title "Running tests..." -- sleep 1

  local all_passed=true
  local results=()

  for suite in "${test_suites[@]}"; do
    IFS=':' read -r suite_name suite_command <<< "$suite"
    if ! run_test_suite "$suite_name" "$suite_command"; then
      all_passed=false
    fi
    results+=("$suite_name:$?")
  done

  # Print summary of test results
  gum style --border double --border-foreground 99 --padding "1 2" --bold "Test Summary"
  for result in "${results[@]}"; do
    IFS=':' read -r suite_name exit_code <<< "$result"
    if [ "$exit_code" -eq 0 ]; then
      gum style --foreground 46 "✓ $suite_name: PASSED"
    else
      gum style --foreground 196 "✗ $suite_name: FAILED"
    fi
  done

  # Return overall exit code
  $all_passed
}

# Main execution
if check_tika_server; then
  run_tests
  exit $?
else
  exit 1
fi
Makefile (34 lines deleted)
@@ -1,34 +0,0 @@
.DEFAULT_TARGET=help

## help: Display list of commands
.PHONY: help
help:
	@echo "Available commands:"
	@sed -n 's|^##||p' $(MAKEFILE_LIST) | column -t -s ':' | sed -e 's|^| |'


## dev: Start development environment
.PHONY: dev
dev:
	DOCKER_BUILDKIT=1 docker compose -f docker-compose.dev.yml up --build

dev-build:
	DOCKER_BUILDKIT=1 docker compose -f docker-compose.dev.yml build --no-cache
	DOCKER_BUILDKIT=1 docker compose -f docker-compose.dev.yml up

## prod: Build and start production environment
.PHONY: prod
prod:
	docker compose -f docker-compose.yml up --build

## front: Build and start frontend
.PHONY: front
front:
	cd frontend && yarn && yarn build && yarn start

## test: Run tests
.PHONY: test
test:
	# Ensure dependencies are installed with dev and test extras
	# poetry install --with dev,test && brew install tesseract pandoc libmagic
	./.run_tests.sh
Pipfile (11 lines deleted)
@@ -1,11 +0,0 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]

[dev-packages]

[requires]
python_version = "3.11"
@@ -1,14 +0,0 @@
**/.mypy_cache
**/__pycache__
*/.pytest_cache
**/__pycache__
**/.benchmarks/
**/.cache/
**/.pytest_cache/
**/.next/
**/build/
**/.docusaurus/
**/node_modules/
**/.venv/
**/.tox/
**/.tox-docker/
backend/.vscode/launch.json (vendored, 24 lines deleted)
@@ -1,24 +0,0 @@
{
  // Use IntelliSense to learn about possible attributes.
  // Hover to view descriptions of existing attributes.
  // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
  "version": "0.2.0",
  "configurations": [
    {
      "name": "Python: Remote Attach",
      "type": "python",
      "request": "attach",
      "connect": {
        "host": "localhost",
        "port": 5678
      },
      "pathMappings": [
        {
          "localRoot": "${workspaceFolder}",
          "remoteRoot": "."
        }
      ],
      "justMyCode": true
    }
  ]
}
@@ -1,63 +0,0 @@
FROM python:3.11.6-slim-bullseye

WORKDIR /app

# Install runtime dependencies
RUN apt-get clean && apt-get update && apt-get install -y \
    libgeos-dev \
    libcurl4-openssl-dev \
    libssl-dev \
    binutils \
    curl \
    git \
    autoconf \
    automake \
    build-essential \
    libtool \
    python-dev \
    build-essential \
    wget \
    # Additional dependencies for document handling
    libmagic-dev \
    poppler-utils \
    tesseract-ocr \
    libreoffice \
    libpq-dev \
    gcc \
    libhdf5-serial-dev \
    pandoc && \
    rm -rf /var/lib/apt/lists/* && apt-get clean

# Install Supabase CLI
RUN ARCHITECTURE=$(uname -m) && \
    if [ "$ARCHITECTURE" = "x86_64" ]; then \
    wget https://github.com/supabase/cli/releases/download/v1.163.6/supabase_1.163.6_linux_amd64.deb && \
    dpkg -i supabase_1.163.6_linux_amd64.deb && \
    rm supabase_1.163.6_linux_amd64.deb; \
    elif [ "$ARCHITECTURE" = "aarch64" ]; then \
    wget https://github.com/supabase/cli/releases/download/v1.163.6/supabase_1.163.6_linux_arm64.deb && \
    dpkg -i supabase_1.163.6_linux_arm64.deb && \
    rm supabase_1.163.6_linux_arm64.deb; \
    fi

COPY requirements.lock pyproject.toml README.md ./
COPY api/pyproject.toml api/README.md ./api/
COPY api/quivr_api/__init__.py ./api/quivr_api/__init__.py
COPY core/pyproject.toml core/README.md ./core/
COPY core/quivr_core/__init__.py ./core/quivr_core/__init__.py
COPY worker/pyproject.toml worker/README.md ./worker/
COPY worker/quivr_worker/__init__.py ./worker/quivr_worker/__init__.py
COPY worker/diff-assistant/pyproject.toml worker/diff-assistant/README.md ./worker/diff-assistant/
COPY worker/diff-assistant/quivr_diff_assistant/__init__.py ./worker/diff-assistant/quivr_diff_assistant/__init__.py

RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -r requirements.lock

RUN playwright install --with-deps && \
    python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()" && \
    python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"


ENV PYTHONPATH=/app

COPY . .
EXPOSE 5050
@@ -1,48 +0,0 @@
# Using a slim version for a smaller base image
FROM python:3.11.6-slim-bullseye

WORKDIR /app
# Install GEOS library, Rust, and other dependencies, then clean up
RUN apt-get clean && apt-get update && apt-get install -y \
    libgeos-dev \
    libcurl4-openssl-dev \
    libssl-dev \
    binutils \
    curl \
    git \
    autoconf \
    automake \
    build-essential \
    libtool \
    python-dev \
    build-essential \
    # Additional dependencies for document handling
    libmagic-dev \
    poppler-utils \
    tesseract-ocr \
    libreoffice \
    libpq-dev \
    gcc \
    libhdf5-serial-dev \
    pandoc && \
    rm -rf /var/lib/apt/lists/* && apt-get clean

COPY requirements.lock pyproject.toml README.md ./
COPY api/pyproject.toml api/README.md ./api/
COPY api/quivr_api/__init__.py ./api/quivr_api/__init__.py
COPY core/pyproject.toml core/README.md ./core/
COPY core/quivr_core/__init__.py ./core/quivr_core/__init__.py
COPY worker/pyproject.toml worker/README.md ./worker/
COPY worker/quivr_worker/__init__.py ./worker/quivr_worker/__init__.py
COPY worker/diff-assistant/pyproject.toml worker/diff-assistant/README.md ./worker/diff-assistant/
COPY worker/diff-assistant/quivr_diff_assistant/__init__.py ./worker/diff-assistant/quivr_diff_assistant/__init__.py

RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -r requirements.lock

RUN playwright install --with-deps && \
    python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()" && \
    python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"

COPY . .

EXPOSE 5050
@@ -1,42 +0,0 @@
# Using a slim version for a smaller base image
FROM python:3.11.6-slim-bullseye

ARG DEV_MODE
ENV DEV_MODE=$DEV_MODE

RUN apt-get clean && apt-get update && apt-get install -y wget curl

RUN ARCHITECTURE=$(uname -m) && \
    if [ "$ARCHITECTURE" = "x86_64" ]; then \
    wget https://github.com/supabase/cli/releases/download/v1.163.6/supabase_1.163.6_linux_amd64.deb && \
    dpkg -i supabase_1.163.6_linux_amd64.deb && \
    rm supabase_1.163.6_linux_amd64.deb; \
    elif [ "$ARCHITECTURE" = "aarch64" ]; then \
    wget https://github.com/supabase/cli/releases/download/v1.163.6/supabase_1.163.6_linux_arm64.deb && \
    dpkg -i supabase_1.163.6_linux_arm64.deb && \
    rm supabase_1.163.6_linux_arm64.deb; \
    fi && \
    rm -rf /var/lib/apt/lists/*

# TODO(@aminediro) : multistage build. Probably dont neet poetry once its built
# Install Poetry
RUN curl -sSL https://install.python-poetry.org | POETRY_HOME=/opt/poetry python && \
    cd /usr/local/bin && \
    ln -s /opt/poetry/bin/poetry && \
    poetry config virtualenvs.create false && \
    poetry config virtualenvs.in-project false

# Add Rust binaries to the PATH
ENV PATH="/root/.cargo/bin:${PATH}" \
    POETRY_CACHE_DIR=/tmp/poetry_cache \
    PYTHONDONTWRITEBYTECODE=1 \
    POETRY_VIRTUALENVS_PATH=/code/api/.venv-docker

WORKDIR /code/api
COPY . /code/

RUN poetry install && rm -rf $POETRY_CACHE_DIR

ENV PYTHONPATH=/code/api

EXPOSE 5050
@@ -1,45 +0,0 @@
# Using a slim version for a smaller base image
FROM python:3.11.6-slim-bullseye

ARG DEV_MODE
ENV DEV_MODE=$DEV_MODE

RUN apt-get clean && apt-get update && apt-get install -y wget curl

RUN ARCHITECTURE=$(uname -m) && \
    if [ "$ARCHITECTURE" = "x86_64" ]; then \
    wget https://github.com/supabase/cli/releases/download/v1.163.6/supabase_1.163.6_linux_amd64.deb && \
    dpkg -i supabase_1.163.6_linux_amd64.deb && \
    rm supabase_1.163.6_linux_amd64.deb; \
    elif [ "$ARCHITECTURE" = "aarch64" ]; then \
    wget https://github.com/supabase/cli/releases/download/v1.163.6/supabase_1.163.6_linux_arm64.deb && \
    dpkg -i supabase_1.163.6_linux_arm64.deb && \
    rm supabase_1.163.6_linux_arm64.deb; \
    fi && \
    rm -rf /var/lib/apt/lists/*

# TODO(@aminediro) : multistage build. Probably dont neet poetry once its built
# Install Poetry
RUN curl -sSL https://install.python-poetry.org | POETRY_HOME=/opt/poetry python && \
    cd /usr/local/bin && \
    ln -s /opt/poetry/bin/poetry && \
    poetry config virtualenvs.create false && \
    poetry config virtualenvs.in-project false

# Add Rust binaries to the PATH
ENV PATH="/root/.cargo/bin:${PATH}" \
    POETRY_CACHE_DIR=/tmp/poetry_cache \
    PYTHONDONTWRITEBYTECODE=1 \
    POETRY_VIRTUALENVS_PATH=/code/api/.venv-docker

WORKDIR /code/api
COPY api/pyproject.toml api/poetry.lock api/README.md /code/api/
COPY api/quivr_api /code/api/quivr_api

# Run install
# Run install
RUN poetry install --no-directory --no-root --with dev && rm -rf $POETRY_CACHE_DIR

ENV PYTHONPATH=/code/api

EXPOSE 5050
@@ -1 +0,0 @@
# quivr-api 0.1
@@ -1,65 +0,0 @@
[project]
name = "quivr-api"
version = "0.1.0"
description = "quivr backend API"
authors = [{ name = "Stan Girard", email = "stan@quivr.app" }]
dependencies = [
    "quivr-core",
    "supabase>=2.0.0",
    "fastapi>=0.100.0",
    "uvloop>=0.18.0",
    "python-jose>=3.0.0",
    "python-multipart>=0.0.9",
    "uvicorn>=0.25.0",
    "redis>=5.0.0",
    "asyncpg>=0.29.0",
    "psycopg2-binary>=2.9.9",
    "sqlmodel>=0.0.21",
    "celery[redis]>=5.4.0",
    "pydantic-settings>=2.4.0",
    "python-dotenv>=1.0.1",
    "unidecode>=1.3.8",
    "colorlog>=6.8.2",
    "posthog>=3.5.0",
    "pyinstrument>=4.7.2",
    "sentry-sdk[fastapi]>=2.13.0",
    "google-api-python-client>=2.141.0",
    "google-auth-httplib2>=0.2.0",
    "google-auth-oauthlib>=1.2.1",
    "dropbox>=12.0.2",
    "msal>=1.30.0",
    "notion-client>=2.2.1",
    "markdownify>=0.13.1",
    "langchain-openai>=0.1.21",
    "resend>=2.4.0",
    "langchain>=0.2.14,<0.3.0",
    "litellm>=1.43.15",
    "openai>=1.40.8",
    "tiktoken>=0.7.0",
    "langchain-community>=0.2.12",
    "langchain-cohere>=0.2.2",
    "llama-parse>=0.4.9",
    "pgvector>=0.3.2",
]
readme = "README.md"
requires-python = "< 3.12"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = []
universal = true

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["quivr_api"]


[[tool.rye.sources]]
name = "quivr-core"
path = "../quivr-core"
@@ -1,15 +0,0 @@
from quivr_api.modules.brain.entity.brain_entity import Brain
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB

from .modules.chat.entity.chat import Chat, ChatHistory
from .modules.sync.entity.sync_models import NotionSyncFile
from .modules.user.entity.user_identity import User

__all__ = [
    "Chat",
    "ChatHistory",
    "User",
    "NotionSyncFile",
    "KnowledgeDB",
    "Brain",
]
@@ -1,28 +0,0 @@
# celery_config.py
import os

import dotenv
from celery import Celery

dotenv.load_dotenv()

CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "")
CELERY_BROKER_QUEUE_NAME = os.getenv("CELERY_BROKER_QUEUE_NAME", "quivr")

celery = Celery(__name__)

if CELERY_BROKER_URL.startswith("redis"):
    celery = Celery(
        __name__,
        broker=f"{CELERY_BROKER_URL}",
        backend=f"{CELERY_BROKER_URL}",
        task_concurrency=4,
        worker_prefetch_multiplier=2,
        task_serializer="json",
        result_extended=True,
        task_send_sent_event=True,
    )
else:
    raise ValueError(f"Unsupported broker URL: {CELERY_BROKER_URL}")

celery.autodiscover_tasks(["quivr_api.modules.chat"])
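For context, a short sketch of how a task could be registered against this Celery app and enqueued; the task name is hypothetical, and the import path assumes the module above lives at quivr_api/celery_config.py.

```python
# Illustrative only: wiring a task to the `celery` app defined in celery_config.py.
# Assumes CELERY_BROKER_URL points at a reachable Redis instance.
from quivr_api.celery_config import celery


@celery.task(name="example.ping")  # hypothetical task, not part of quivr_api
def ping(payload: str) -> str:
    return f"pong: {payload}"


if __name__ == "__main__":
    # .delay() serializes the call as JSON (task_serializer="json") and pushes it onto
    # the Redis broker; a worker started with `celery -A quivr_api.celery_config worker`
    # would pick it up and store the result in the Redis backend.
    result = ping.delay("hello")
    print(result.id)
```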
@@ -1,247 +0,0 @@
import logging
import os
import queue
import sys
import threading
from logging.handlers import RotatingFileHandler
from typing import List

import orjson
import requests
import structlog

from quivr_api.models.settings import parseable_settings

# Thread-safe queue for log messages
log_queue = queue.Queue()
stop_log_queue = threading.Event()


class ParseableLogHandler(logging.Handler):
    def __init__(
        self,
        base_parseable_url: str,
        auth_token: str,
        stream_name: str,
        batch_size: int = 10,
        flush_interval: float = 1,
    ):
        super().__init__()
        self.base_url = base_parseable_url
        self.stream_name = stream_name
        self.url = self.base_url + self.stream_name
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self._worker_thread = threading.Thread(target=self._process_log_queue)
        self._worker_thread.daemon = True
        self._worker_thread.start()
        self.headers = {
            "Authorization": f"Basic {auth_token}",  # base64 encoding user:mdp
            "Content-Type": "application/json",
        }

    def emit(self, record: logging.LogRecord):
        # FIXME (@AmineDiro): This ping-pong of serialization/deserialization is a limitation of logging formatter
        # The formatter should return a 'str' for the logger to print
        if isinstance(record.msg, str):
            return
        elif isinstance(record.msg, dict):
            logger_name = record.msg.get("logger", None)
            if logger_name and (
                logger_name.startswith("quivr_api.access")
                or logger_name.startswith("quivr_api.error")
            ):
                url = record.msg.get("url", None)
                # Filter on healthz
                if url and "healthz" not in url:
                    fmt = orjson.loads(self.format(record))
                    log_queue.put(fmt)
        else:
            return

    def _process_log_queue(self):
        """Background thread that processes the log queue and sends logs to Parseable."""
        logs_batch = []
        while not stop_log_queue.is_set():
            try:
                # Collect logs for batch processing
                log_data = log_queue.get(timeout=self.flush_interval)
                logs_batch.append(log_data)

                # Send logs if batch size is reached
                if len(logs_batch) >= self.batch_size:
                    self._send_logs_to_parseable(logs_batch)
                    logs_batch.clear()

            except queue.Empty:
                # If the queue is empty, send any remaining logs
                if logs_batch:
                    self._send_logs_to_parseable(logs_batch)
                    logs_batch.clear()

    def _send_logs_to_parseable(self, logs: List[str]):
        payload = orjson.dumps(logs)
        try:
            response = requests.post(self.url, headers=self.headers, data=payload)
            if response.status_code != 200:
                print(f"Failed to send logs to Parseable server: {response.text}")
        except Exception as e:
            print(f"Error sending logs to Parseable: {e}")

    def stop(self):
        """Stop the background worker thread and process any remaining logs."""
        stop_log_queue.set()
        self._worker_thread.join()
        # Process remaining logs before shutting down
        remaining_logs = list(log_queue.queue)
        if remaining_logs:
            self._send_logs_to_parseable(remaining_logs)


def extract_from_record(_, __, event_dict):
    """
    Extract thread and process names and add them to the event dict.
    """
    record = event_dict["_record"]
    event_dict["thread_name"] = record.threadName
    event_dict["process_name"] = record.processName
    return event_dict


def drop_http_context(_, __, event_dict):
    """
    Extract thread and process names and add them to the event dict.
    """
    keys = ["msg", "logger", "level", "timestamp", "exc_info"]
    return {k: event_dict.get(k, None) for k in keys}


def setup_logger(
    log_file="application.log", send_log_server: bool = parseable_settings.use_parseable
):
    structlog.reset_defaults()
    # Shared handlers
    shared_processors = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_log_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.UnicodeDecoder(),
        structlog.processors.EventRenamer("msg"),
    ]
    structlog.configure(
        processors=shared_processors
        + [
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        # Use standard logging compatible logger
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        # Use Python's logging configuration
        cache_logger_on_first_use=True,
    )
    # Set Formatters
    plain_fmt = structlog.stdlib.ProcessorFormatter(
        foreign_pre_chain=shared_processors,
        processors=[
            extract_from_record,
            structlog.processors.format_exc_info,
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            structlog.dev.ConsoleRenderer(
                colors=False, exception_formatter=structlog.dev.plain_traceback
            ),
        ],
    )
    color_fmt = structlog.stdlib.ProcessorFormatter(
        processors=[
            drop_http_context,
            structlog.dev.ConsoleRenderer(
                colors=True,
                exception_formatter=structlog.dev.RichTracebackFormatter(
                    show_locals=False
                ),
            ),
        ],
        foreign_pre_chain=shared_processors,
    )
    parseable_fmt = structlog.stdlib.ProcessorFormatter(
        processors=[
            # TODO: Which one gets us the better debug experience ?
            # structlog.processors.ExceptionRenderer(
            #     exception_formatter=structlog.tracebacks.ExceptionDictTransformer(
            #         show_locals=False
            #     )
            # ),
            structlog.processors.format_exc_info,
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            structlog.processors.JSONRenderer(),
        ],
        foreign_pre_chain=shared_processors
        + [
            structlog.processors.CallsiteParameterAdder(
                {
                    structlog.processors.CallsiteParameter.FUNC_NAME,
                    structlog.processors.CallsiteParameter.LINENO,
                }
            ),
        ],
    )

    # Set handlers
    console_handler = logging.StreamHandler(sys.stdout)
    file_handler = RotatingFileHandler(
        log_file, maxBytes=5000000, backupCount=5
    )  # 5MB file
    console_handler.setFormatter(color_fmt)
    file_handler.setFormatter(plain_fmt)
    handlers: list[logging.Handler] = [console_handler, file_handler]
    if (
        send_log_server
        and parseable_settings.parseable_url is not None
        and parseable_settings.parseable_auth is not None
        and parseable_settings.parseable_stream_name
    ):
        parseable_handler = ParseableLogHandler(
            auth_token=parseable_settings.parseable_auth,
            base_parseable_url=parseable_settings.parseable_url,
            stream_name=parseable_settings.parseable_stream_name,
        )
        parseable_handler.setFormatter(parseable_fmt)
        handlers.append(parseable_handler)

    # Configure logger
    log_level = os.getenv("LOG_LEVEL", "INFO").upper()
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    root_logger.handlers = []
    for handler in handlers:
        root_logger.addHandler(handler)

    _clear_uvicorn_logger()


def _clear_uvicorn_logger():
    for _log in [
        "uvicorn",
        "httpcore",
        "uvicorn.error",
        "uvicorn.access",
        "urllib3",
        "httpx",
    ]:
        # Clear the log handlers for uvicorn loggers, and enable propagation
        # so the messages are caught by our root logger and formatted correctly
        # by structlog
        logging.getLogger(_log).setLevel(logging.WARNING)
        logging.getLogger(_log).handlers.clear()
        logging.getLogger(_log).propagate = True


setup_logger()


def get_logger(name: str | None = None):
    assert structlog.is_configured()
    return structlog.get_logger(name)
@ -1,122 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import sentry_sdk
|
||||
from dotenv import load_dotenv # type: ignore
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pyinstrument import Profiler
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
|
||||
from quivr_api.logger import get_logger, stop_log_queue
|
||||
from quivr_api.middlewares.cors import add_cors_middleware
|
||||
from quivr_api.middlewares.logging_middleware import LoggingMiddleware
|
||||
from quivr_api.modules.analytics.controller.analytics_routes import analytics_router
|
||||
from quivr_api.modules.api_key.controller import api_key_router
|
||||
from quivr_api.modules.assistant.controller import assistant_router
|
||||
from quivr_api.modules.brain.controller import brain_router
|
||||
from quivr_api.modules.chat.controller import chat_router
|
||||
from quivr_api.modules.knowledge.controller import knowledge_router
|
||||
from quivr_api.modules.misc.controller import misc_router
|
||||
from quivr_api.modules.models.controller.model_routes import model_router
|
||||
from quivr_api.modules.onboarding.controller import onboarding_router
|
||||
from quivr_api.modules.prompt.controller import prompt_router
|
||||
from quivr_api.modules.sync.controller import sync_router
|
||||
from quivr_api.modules.upload.controller import upload_router
|
||||
from quivr_api.modules.user.controller import user_router
|
||||
from quivr_api.routes.crawl_routes import crawl_router
|
||||
from quivr_api.routes.subscription_routes import subscription_router
|
||||
from quivr_api.utils.telemetry import maybe_send_telemetry
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger("quivr_api")
|
||||
|
||||
|
||||
def before_send(event, hint):
|
||||
# If this is a transaction event
|
||||
if event["type"] == "transaction":
|
||||
# And the transaction name contains 'healthz'
|
||||
if "healthz" in event["transaction"]:
|
||||
# Drop the event by returning None
|
||||
return None
|
||||
# For other events, return them as is
|
||||
return event
|
||||
|
||||
|
||||
sentry_dsn = os.getenv("SENTRY_DSN")
|
||||
if sentry_dsn:
|
||||
sentry_sdk.init(
|
||||
dsn=sentry_dsn,
|
||||
sample_rate=0.1,
|
||||
enable_tracing=True,
|
||||
traces_sample_rate=0.1,
|
||||
integrations=[
|
||||
StarletteIntegration(transaction_style="url"),
|
||||
FastApiIntegration(transaction_style="url"),
|
||||
],
|
||||
before_send=before_send,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
add_cors_middleware(app)
|
||||
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
|
||||
|
||||
app.include_router(brain_router)
|
||||
app.include_router(chat_router)
|
||||
app.include_router(crawl_router)
|
||||
app.include_router(assistant_router)
|
||||
app.include_router(sync_router)
|
||||
app.include_router(onboarding_router)
|
||||
app.include_router(misc_router)
|
||||
app.include_router(analytics_router)
|
||||
app.include_router(upload_router)
|
||||
app.include_router(user_router)
|
||||
app.include_router(api_key_router)
|
||||
app.include_router(subscription_router)
|
||||
app.include_router(prompt_router)
|
||||
app.include_router(knowledge_router)
|
||||
app.include_router(model_router)
|
||||
|
||||
PROFILING = os.getenv("PROFILING", "false").lower() == "true"
|
||||
|
||||
|
||||
if PROFILING:
|
||||
|
||||
@app.middleware("http")
|
||||
async def profile_request(request: Request, call_next):
|
||||
profiling = request.query_params.get("profile", False)
|
||||
if profiling:
|
||||
profiler = Profiler()
|
||||
profiler.start()
|
||||
await call_next(request)
|
||||
profiler.stop()
|
||||
return HTMLResponse(profiler.output_html())
|
||||
else:
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event():
|
||||
stop_log_queue.set()
|
||||
|
||||
|
||||
if os.getenv("TELEMETRY_ENABLED") == "true":
|
||||
logger.info("Telemetry enabled, we use telemetry to collect anonymous usage data.")
|
||||
logger.info(
|
||||
"To disable telemetry, set the TELEMETRY_ENABLED environment variable to false."
|
||||
)
|
||||
maybe_send_telemetry("booting", {"status": "ok"})
|
||||
maybe_send_telemetry("ping", {"ping": "pong"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# run main.py to debug backend
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=5050, log_level="debug", access_log=False)
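A quick local check of the optional profiling middleware above, as a sketch: with PROFILING=true, adding ?profile=true to a request returns the pyinstrument HTML report instead of the normal response. The /healthz path is an assumption (it is only referenced indirectly by the Sentry filter above); any existing route would behave the same way.

import httpx

# Assumes the API is running locally on port 5050 with PROFILING=true
resp = httpx.get("http://localhost:5050/healthz", params={"profile": "true"})
print(resp.headers.get("content-type"))  # expected: text/html when profiling is active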
|
@ -1,6 +0,0 @@
|
||||
from .auth_bearer import AuthBearer, get_current_user

__all__ = [
    "AuthBearer",
    "get_current_user",
]

|
@ -1,78 +0,0 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from quivr_api.middlewares.auth.jwt_token_handler import (
|
||||
decode_access_token,
|
||||
verify_token,
|
||||
)
|
||||
from quivr_api.modules.api_key.service.api_key_service import ApiKeyService
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
api_key_service = ApiKeyService()
|
||||
|
||||
logger = structlog.stdlib.get_logger("quivr_api.access")
|
||||
|
||||
|
||||
class AuthBearer(HTTPBearer):
|
||||
def __init__(self, auto_error: bool = True):
|
||||
super().__init__(auto_error=auto_error)
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
request: Request,
|
||||
):
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = await super().__call__(
|
||||
request
|
||||
)
|
||||
self.check_scheme(credentials)
|
||||
token = credentials.credentials # pyright: ignore reportPrivateUsage=none
|
||||
return await self.authenticate(
|
||||
token,
|
||||
)
|
||||
|
||||
def check_scheme(self, credentials):
|
||||
if credentials and credentials.scheme != "Bearer":
|
||||
raise HTTPException(status_code=401, detail="Token must be Bearer")
|
||||
elif not credentials:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Authentication credentials missing"
|
||||
)
|
||||
|
||||
async def authenticate(
|
||||
self,
|
||||
token: str,
|
||||
) -> UserIdentity:
|
||||
if os.environ.get("AUTHENTICATE") == "false":
|
||||
return self.get_test_user()
|
||||
elif verify_token(token):
|
||||
return decode_access_token(token)
|
||||
elif await api_key_service.verify_api_key(
|
||||
token,
|
||||
):
|
||||
return await api_key_service.get_user_from_api_key(
|
||||
token,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=401, detail="Invalid token or api key.")
|
||||
|
||||
def get_test_user(self) -> UserIdentity:
|
||||
return UserIdentity(
|
||||
email="admin@quivr.app",
|
||||
id="39418e3b-0258-4452-af60-7acfcc1263ff", # type: ignore
|
||||
) # replace with test user information
|
||||
|
||||
|
||||
auth_bearer = AuthBearer()
|
||||
|
||||
|
||||
async def get_current_user(user: UserIdentity = Depends(auth_bearer)) -> UserIdentity:
|
||||
# Due to context switch in FastAPI executor we can't get this id back
|
||||
# We log it as an additional log so we can get information if exception was raised
|
||||
# https://www.structlog.org/en/stable/contextvars.html
|
||||
structlog.contextvars.bind_contextvars(client_id=str(user.id))
|
||||
logger.info("Authentication success")
|
||||
return user
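For reference, a sketch of how this dependency pair is wired into a route, following the same pattern the routers later in this diff use (the /example path is illustrative only):

from fastapi import APIRouter, Depends

from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.user.entity.user_identity import UserIdentity

example_router = APIRouter()


@example_router.get("/example", dependencies=[Depends(AuthBearer())], tags=["Example"])
async def read_example(current_user: UserIdentity = Depends(get_current_user)):
    # client_id is already bound to the structlog context by get_current_user
    return {"id": str(current_user.id), "email": current_user.email}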
|
@ -1,44 +0,0 @@
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from jose import jwt
|
||||
from jose.exceptions import JWTError
|
||||
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
SECRET_KEY = os.environ.get("JWT_SECRET_KEY")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
if not SECRET_KEY:
|
||||
raise ValueError("JWT_SECRET_KEY environment variable not set")
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> UserIdentity | None:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, SECRET_KEY, algorithms=[ALGORITHM], options={"verify_aud": False}
|
||||
)
|
||||
except JWTError:
|
||||
return None # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
return UserIdentity(
|
||||
email=payload.get("email"),
|
||||
id=payload.get("sub"), # pyright: ignore reportPrivateUsage=none
|
||||
)
|
||||
|
||||
|
||||
def verify_token(token: str):
|
||||
payload = decode_access_token(token)
|
||||
return payload is not None
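A round-trip sketch of the three helpers above (assumes JWT_SECRET_KEY is exported before the module is imported, as the check at the top requires; the id and email reuse the test-user values from auth_bearer.py):

from datetime import timedelta

token = create_access_token(
    {"sub": "39418e3b-0258-4452-af60-7acfcc1263ff", "email": "admin@quivr.app"},
    expires_delta=timedelta(hours=1),
)
assert verify_token(token)
user = decode_access_token(token)
assert user is not None and user.email == "admin@quivr.app"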
|
@ -1,23 +0,0 @@
|
||||
from fastapi.middleware.cors import CORSMiddleware

origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://localhost:3001",
    "https://quivr.app",
    "https://www.quivr.app",
    "http://quivr.app",
    "http://www.quivr.app",
    "https://chat.quivr.app",
    "*",
]


def add_cors_middleware(app):
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
|
@ -1,95 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import structlog
|
||||
from fastapi import Request, Response, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from structlog.contextvars import (
|
||||
bind_contextvars,
|
||||
clear_contextvars,
|
||||
)
|
||||
|
||||
logger = structlog.stdlib.get_logger("quivr_api.access")
|
||||
|
||||
|
||||
git_sha = os.getenv("PORTER_IMAGE_TAG", None)
|
||||
|
||||
|
||||
def clean_dict(d):
|
||||
"""Remove None values from a dictionary."""
|
||||
return {k: v for k, v in d.items() if v is not None}
|
||||
|
||||
|
||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
clear_contextvars()
|
||||
# Generate a unique request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
client_addr = (
|
||||
f"{request.client.host}:{request.client.port}" if request.client else None
|
||||
)
|
||||
url = request.url.path
|
||||
http_version = request.scope["http_version"]
|
||||
|
||||
bind_contextvars(
|
||||
**clean_dict(
|
||||
{
|
||||
"git_head": git_sha,
|
||||
"request_id": request_id,
|
||||
"method": request.method,
|
||||
"query_params": dict(request.query_params),
|
||||
"client_addr": client_addr,
|
||||
"request_user_agent": request.headers.get("user-agent"),
|
||||
"request_content_type": request.headers.get("content-type"),
|
||||
"url": url,
|
||||
"http_version": http_version,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Start time
|
||||
start_time = time.perf_counter()
|
||||
response = Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
try:
|
||||
# Process the request
|
||||
response: Response = await call_next(request)
|
||||
process_time = time.perf_counter() - start_time
|
||||
bind_contextvars(
|
||||
**clean_dict(
|
||||
{
|
||||
"response_content_type": response.headers.get("content-type"),
|
||||
"response_status": response.status_code,
|
||||
"response_headers": dict(response.headers),
|
||||
"timing_request_total_ms": round(process_time * 1e3, 3),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"""{client_addr} - "{request.method} {url} HTTP/{http_version}" {response.status_code}""",
|
||||
)
|
||||
except Exception:
|
||||
process_time = time.perf_counter() - start_time
|
||||
bind_contextvars(
|
||||
**clean_dict(
|
||||
{
|
||||
"response_status": response.status_code,
|
||||
"timing_request_total_ms": round(process_time * 1000, 3),
|
||||
}
|
||||
)
|
||||
)
|
||||
structlog.stdlib.get_logger("quivr_api.error").exception(
|
||||
"Request failed with exception"
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
clear_contextvars()
|
||||
|
||||
# Add X-Request-ID to response headers
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
|
||||
return response
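A self-contained sketch of the middleware in isolation, using FastAPI's TestClient (the /ping route is illustrative); every response carries the correlation headers added above:

from fastapi import FastAPI
from fastapi.testclient import TestClient

from quivr_api.middlewares.logging_middleware import LoggingMiddleware

app = FastAPI()
app.add_middleware(LoggingMiddleware)


@app.get("/ping")
def ping():
    return {"pong": True}


client = TestClient(app)
response = client.get("/ping")
print(response.headers["x-request-id"], response.headers["x-process-time"])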
|
@ -1,14 +0,0 @@
|
||||
from uuid import UUID

from pydantic import BaseModel, ConfigDict

from quivr_api.logger import get_logger

logger = get_logger(__name__)


class BrainSubscription(BaseModel):
    brain_id: UUID
    email: str
    rights: str = "Viewer"
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
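For illustration, constructing the model (values are placeholders):

from uuid import uuid4

subscription = BrainSubscription(brain_id=uuid4(), email="invitee@example.com")
print(subscription.rights)  # "Viewer" unless set explicitly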
@ -1,11 +0,0 @@
|
||||
import os

from pydantic import BaseModel


class CrawlWebsite(BaseModel):
    url: str
    js: bool = False
    depth: int = int(os.getenv("CRAWL_DEPTH", "1"))
    max_pages: int = 100
    max_time: int = 60
|
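For illustration (the URL is a placeholder; depth falls back to the CRAWL_DEPTH environment variable, defaulting to 1):

site = CrawlWebsite(url="https://example.com")
print(site.depth, site.max_pages, site.max_time)  # 1 100 60 with the defaults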
@ -1,13 +0,0 @@
|
||||
from pydantic import BaseModel


class LLMModel(BaseModel):
    """LLM model stored in the database that users are allowed to use."""

    name: str = "gpt-3.5-turbo-0125"
    price: int = 1
    max_input: int = 512
    max_output: int = 512
|
@ -1,85 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from .llm_models import LLMModel
|
||||
|
||||
|
||||
class Repository(ABC):
|
||||
@abstractmethod
|
||||
def create_user_daily_usage(self, user_id: UUID, user_email: str, date: datetime):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_user_usage(self, user_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_models(self) -> LLMModel | None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_user_requests_count_for_month(self, user_id: UUID, date: datetime):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_user_request_count(self, user_id: UUID, date: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def increment_user_request_count(
|
||||
self, user_id: UUID, date: str, current_request_count
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_file_vectors_ids(self, file_sha1: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_brain_vectors_by_brain_id_and_file_sha1(
|
||||
self, brain_id: UUID, file_sha1: str
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_subscription_invitation(
|
||||
self, brain_id: UUID, user_email: str, rights: str
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_subscription_invitation(
|
||||
self, brain_id: UUID, user_email: str, rights: str
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_subscription_invitations_by_brain_id_and_email(
|
||||
self, brain_id: UUID, user_email: str
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_vectors_by_file_name(self, file_name: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def similarity_search(self, query_embedding, table: str, k: int, threshold: float):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_summary(self, document_id: UUID, summary_id: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_vectors_by_batch(self, batch_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_vectors_in_batch(self, batch_ids):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_vectors_by_file_sha1(self, file_sha1):
|
||||
pass
|
@ -1,8 +0,0 @@
|
||||
from quivr_api.models.databases.supabase.brains_subscription_invitations import (
    BrainSubscription,
)
from quivr_api.models.databases.supabase.files import File
from quivr_api.models.databases.supabase.user_usage import UserUsage
from quivr_api.models.databases.supabase.vectors import Vector

__all__ = ["BrainSubscription", "File", "UserUsage", "Vector"]
|
@ -1,40 +0,0 @@
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.models.databases.repository import Repository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BrainSubscription(Repository):
|
||||
def __init__(self, supabase_client):
|
||||
self.db = supabase_client
|
||||
|
||||
def create_subscription_invitation(self, brain_id, user_email, rights):
|
||||
logger.info("Creating subscription invitation")
|
||||
response = (
|
||||
self.db.table("brain_subscription_invitations")
|
||||
.insert({"brain_id": str(brain_id), "email": user_email, "rights": rights})
|
||||
.execute()
|
||||
)
|
||||
return response.data
|
||||
|
||||
def update_subscription_invitation(self, brain_id, user_email, rights):
|
||||
logger.info("Updating subscription invitation")
|
||||
response = (
|
||||
self.db.table("brain_subscription_invitations")
|
||||
.update({"rights": rights})
|
||||
.eq("brain_id", str(brain_id))
|
||||
.eq("email", user_email)
|
||||
.execute()
|
||||
)
|
||||
return response.data
|
||||
|
||||
def get_subscription_invitations_by_brain_id_and_email(self, brain_id, user_email):
|
||||
response = (
|
||||
self.db.table("brain_subscription_invitations")
|
||||
.select("*")
|
||||
.eq("brain_id", str(brain_id))
|
||||
.eq("email", user_email)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return response
|
@ -1,28 +0,0 @@
|
||||
from quivr_api.models.databases.repository import Repository
|
||||
|
||||
|
||||
class File(Repository):
|
||||
def __init__(self, supabase_client):
|
||||
self.db = supabase_client
|
||||
|
||||
def set_file_vectors_ids(self, file_sha1):
|
||||
response = (
|
||||
self.db.table("vectors")
|
||||
.select("id")
|
||||
.filter("file_sha1", "eq", file_sha1)
|
||||
.execute()
|
||||
)
|
||||
return response.data
|
||||
|
||||
def get_brain_vectors_by_brain_id_and_file_sha1(self, brain_id, file_sha1):
|
||||
self.set_file_vectors_ids(file_sha1)
|
||||
# Check if file exists in that brain
|
||||
response = (
|
||||
self.db.table("brains_vectors")
|
||||
.select("brain_id, vector_id")
|
||||
.filter("brain_id", "eq", str(brain_id))
|
||||
.filter("file_sha1", "eq", file_sha1)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return response
|
@ -1,21 +0,0 @@
|
||||
from quivr_api.models.databases.supabase import (
    BrainSubscription,
    File,
    UserUsage,
    Vector,
)


# TODO: REMOVE THIS CLASS !
class SupabaseDB(
    UserUsage,
    File,
    BrainSubscription,
    Vector,
):
    def __init__(self, supabase_client):
        self.db = supabase_client
        UserUsage.__init__(self, supabase_client)
        File.__init__(self, supabase_client)
        BrainSubscription.__init__(self, supabase_client)
        Vector.__init__(self, supabase_client)
|
@ -1,128 +0,0 @@
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.models.databases.repository import Repository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: change the name of this class because another one already exists
|
||||
class UserUsage(Repository):
|
||||
def __init__(self, supabase_client):
|
||||
self.db = supabase_client
|
||||
|
||||
def create_user_daily_usage(
|
||||
self, user_id: UUID, user_email: str, date: datetime, number: int = 1
|
||||
):
|
||||
return (
|
||||
self.db.table("user_daily_usage")
|
||||
.insert(
|
||||
{
|
||||
"user_id": str(user_id),
|
||||
"email": user_email,
|
||||
"date": date,
|
||||
"daily_requests_count": number,
|
||||
}
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
def get_user_settings(self, user_id):
|
||||
"""
|
||||
Fetch the user settings from the database
|
||||
"""
|
||||
|
||||
user_settings_response = (
|
||||
self.db.from_("user_settings")
|
||||
.select("*")
|
||||
.filter("user_id", "eq", str(user_id))
|
||||
.execute()
|
||||
).data
|
||||
|
||||
if len(user_settings_response) == 0:
|
||||
# Create the user settings
|
||||
user_settings_response = (
|
||||
self.db.table("user_settings")
|
||||
.insert({"user_id": str(user_id)})
|
||||
.execute()
|
||||
).data
|
||||
|
||||
if len(user_settings_response) == 0:
|
||||
raise ValueError("User settings could not be created")
|
||||
|
||||
user_settings = user_settings_response[0]
|
||||
|
||||
return user_settings
|
||||
|
||||
def get_models(self):
|
||||
model_settings_response = (self.db.from_("models").select("*").execute()).data
|
||||
if len(model_settings_response) == 0:
|
||||
raise ValueError("An issue occurred while fetching the model settings")
|
||||
return model_settings_response
|
||||
|
||||
def get_user_monthly(self, user_id):
|
||||
pass
|
||||
|
||||
def get_user_usage(self, user_id):
|
||||
"""
|
||||
Fetch the user request stats from the database
|
||||
"""
|
||||
requests_stats = (
|
||||
self.db.from_("user_daily_usage")
|
||||
.select("*")
|
||||
.filter("user_id", "eq", user_id)
|
||||
.execute()
|
||||
)
|
||||
return requests_stats.data
|
||||
|
||||
def get_user_requests_count_for_day(self, user_id, date):
|
||||
"""
|
||||
Fetch the user request count from the database
|
||||
"""
|
||||
response = (
|
||||
self.db.from_("user_daily_usage")
|
||||
.select("daily_requests_count")
|
||||
.filter("user_id", "eq", user_id)
|
||||
.filter("date", "eq", date)
|
||||
.execute()
|
||||
).data
|
||||
|
||||
if response and len(response) > 0:
|
||||
return response[0]["daily_requests_count"]
|
||||
return 0
|
||||
|
||||
def get_user_requests_count_for_month(self, user_id, date):
|
||||
"""
|
||||
Fetch the user request count from the database
|
||||
"""
|
||||
date_30_days_ago = (datetime.now() - timedelta(days=30)).strftime("%Y%m%d")
|
||||
|
||||
response = (
|
||||
self.db.from_("user_daily_usage")
|
||||
.select("daily_requests_count")
|
||||
.filter("user_id", "eq", user_id)
|
||||
.filter("date", "gte", date_30_days_ago)
|
||||
.execute()
|
||||
).data
|
||||
|
||||
if response and len(response) > 0:
|
||||
return sum(row["daily_requests_count"] for row in response)
|
||||
return 0
|
||||
|
||||
def increment_user_request_count(self, user_id, date, number: int = 1):
|
||||
"""
|
||||
Increment the user's requests count for a specific day
|
||||
"""
|
||||
|
||||
self.update_user_request_count(user_id, daily_requests_count=number, date=date)
|
||||
|
||||
def update_user_request_count(self, user_id, daily_requests_count, date):
|
||||
response = (
|
||||
self.db.table("user_daily_usage")
|
||||
.update({"daily_requests_count": daily_requests_count})
|
||||
.match({"user_id": user_id, "date": date})
|
||||
.execute()
|
||||
)
|
||||
|
||||
return response
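A sketch of the read-then-update flow these methods support, assuming an already-initialized Supabase client and that the date column stores the same "%Y%m%d" string format used by the monthly query above (both are assumptions, not guarantees):

from datetime import datetime
from uuid import UUID

usage = UserUsage(supabase_client)  # supabase_client: an already-initialized client (assumed)
user_id = UUID("39418e3b-0258-4452-af60-7acfcc1263ff")  # placeholder user
today = datetime.now().strftime("%Y%m%d")

count = usage.get_user_requests_count_for_day(user_id, today)
usage.update_user_request_count(user_id, daily_requests_count=count + 1, date=today)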
|
@ -1,76 +0,0 @@
|
||||
from quivr_api.models.databases.repository import Repository
|
||||
|
||||
|
||||
class Vector(Repository):
|
||||
def __init__(self, supabase_client):
|
||||
self.db = supabase_client
|
||||
|
||||
def get_vectors_by_file_name(self, file_name):
|
||||
response = (
|
||||
self.db.table("vectors")
|
||||
.select(
|
||||
"metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url",
|
||||
"content",
|
||||
"brains_vectors(brain_id,vector_id)",
|
||||
)
|
||||
.match({"metadata->>file_name": file_name})
|
||||
.execute()
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def get_vectors_by_file_sha1(self, file_sha1):
|
||||
response = (
|
||||
self.db.table("vectors")
|
||||
.select("id")
|
||||
.filter("file_sha1", "eq", file_sha1)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
# TODO: remove duplicate similarity_search in supabase vector store
|
||||
def similarity_search(self, query_embedding, table, k, threshold):
|
||||
response = self.db.rpc(
|
||||
table,
|
||||
{
|
||||
"query_embedding": query_embedding,
|
||||
"match_count": k,
|
||||
"match_threshold": threshold,
|
||||
},
|
||||
).execute()
|
||||
return response
|
||||
|
||||
def update_summary(self, document_id, summary_id):
|
||||
return (
|
||||
self.db.table("summaries")
|
||||
.update({"document_id": document_id})
|
||||
.match({"id": summary_id})
|
||||
.execute()
|
||||
)
|
||||
|
||||
def get_vectors_by_batch(self, batch_id):
|
||||
response = (
|
||||
self.db.table("vectors")
|
||||
.select(
|
||||
"name:metadata->>file_name, size:metadata->>file_size",
|
||||
count="exact",
|
||||
)
|
||||
.eq("id", batch_id)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def get_vectors_in_batch(self, batch_ids):
|
||||
response = (
|
||||
self.db.table("vectors")
|
||||
.select(
|
||||
"name:metadata->>file_name, size:metadata->>file_size",
|
||||
count="exact",
|
||||
)
|
||||
.in_("id", batch_ids)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return response
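A sketch of the RPC-based search above; "match_vectors" is a placeholder for whatever match function the database actually exposes, and the embedding would normally come from the embedding model (dimension taken from the embedding_dim default in settings.py below):

vectors = Vector(supabase_client)  # supabase_client: an already-initialized client (assumed)
query_embedding = [0.0] * 1536     # placeholder embedding

response = vectors.similarity_search(
    query_embedding, table="match_vectors", k=5, threshold=0.5
)
print(len(response.data))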
|
@ -1,138 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from posthog import Posthog
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class BrainRateLimiting(BaseSettings):
|
||||
model_config = SettingsConfigDict(validate_default=False)
|
||||
max_brain_per_user: int = 5
|
||||
|
||||
|
||||
class SendEmailSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(validate_default=False)
|
||||
resend_contact_sales_from: str = "null"
|
||||
resend_contact_sales_to: str = "null"
|
||||
|
||||
|
||||
# The `PostHogSettings` class is used to initialize and interact with the PostHog analytics service.
|
||||
class PostHogSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(validate_default=False)
|
||||
posthog_api_key: str | None = None
|
||||
posthog_api_url: str | None = None
|
||||
posthog: Posthog | None = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
The function initializes the "posthog" attribute and calls the "initialize_posthog" method.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.posthog = None
|
||||
self.initialize_posthog()
|
||||
|
||||
def initialize_posthog(self):
|
||||
"""
|
||||
The function initializes a PostHog client with an API key and URL.
|
||||
"""
|
||||
if self.posthog_api_key and self.posthog_api_url:
|
||||
self.posthog = Posthog(
|
||||
api_key=self.posthog_api_key, host=self.posthog_api_url
|
||||
)
|
||||
|
||||
def log_event(self, user_id: UUID, event_name: str, event_properties: dict):
|
||||
"""
|
||||
The function logs an event with a user ID, event name, and event properties using the PostHog
|
||||
analytics tool.
|
||||
|
||||
:param user_id: The user_id parameter is a UUID (Universally Unique Identifier) that uniquely
|
||||
identifies a user. It is typically used to track and identify individual users in an application
|
||||
or system
|
||||
:type user_id: UUID
|
||||
:param event_name: The event_name parameter is a string that represents the name or type of the
|
||||
event that you want to log. It could be something like "user_signed_up", "item_purchased", or
|
||||
"page_viewed"
|
||||
:type event_name: str
|
||||
:param event_properties: The event_properties parameter is a dictionary that contains additional
|
||||
information or properties related to the event being logged. These properties provide more
|
||||
context or details about the event and can be used for analysis or filtering purposes
|
||||
:type event_properties: dict
|
||||
"""
|
||||
if self.posthog:
|
||||
self.posthog.capture(user_id, event_name, event_properties)
|
||||
|
||||
def set_user_properties(self, user_id: UUID, event_name, properties: dict):
|
||||
"""
|
||||
The function sets user properties for a given user ID and event name using the PostHog analytics
|
||||
tool.
|
||||
|
||||
:param user_id: The user_id parameter is a UUID (Universally Unique Identifier) that uniquely
|
||||
identifies a user. It is used to associate the user with the event and properties being captured
|
||||
:type user_id: UUID
|
||||
:param event_name: The `event_name` parameter is a string that represents the name of the event
|
||||
that you want to capture. It could be something like "user_signed_up" or "item_purchased"
|
||||
:param properties: The `properties` parameter is a dictionary that contains the user properties
|
||||
that you want to set. Each key-value pair in the dictionary represents a user property, where
|
||||
the key is the name of the property and the value is the value you want to set for that property
|
||||
:type properties: dict
|
||||
"""
|
||||
if self.posthog:
|
||||
self.posthog.capture(
|
||||
user_id, event=event_name, properties={"$set": properties}
|
||||
)
|
||||
|
||||
def set_once_user_properties(self, user_id: UUID, event_name, properties: dict):
|
||||
"""
|
||||
The function sets user properties for a specific event, ensuring that the properties are only
|
||||
set once.
|
||||
|
||||
:param user_id: The user_id parameter is a UUID (Universally Unique Identifier) that uniquely
|
||||
identifies a user
|
||||
:type user_id: UUID
|
||||
:param event_name: The `event_name` parameter is a string that represents the name of the event
|
||||
that you want to capture. It could be something like "user_signed_up" or "item_purchased"
|
||||
:param properties: The `properties` parameter is a dictionary that contains the user properties
|
||||
that you want to set. Each key-value pair in the dictionary represents a user property, where
|
||||
the key is the property name and the value is the property value
|
||||
:type properties: dict
|
||||
"""
|
||||
if self.posthog:
|
||||
self.posthog.capture(
|
||||
user_id, event=event_name, properties={"$set_once": properties}
|
||||
)
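A usage sketch of the class above (events are only sent when POSTHOG_API_KEY and POSTHOG_API_URL are present in the environment; the event names and properties are placeholders):

from uuid import uuid4

posthog_settings = PostHogSettings()
user_id = uuid4()
posthog_settings.log_event(user_id, "chat_message_sent", {"model": "gpt-3.5-turbo-0125"})
posthog_settings.set_once_user_properties(user_id, "user_created", {"signup_source": "web"})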
|
||||
|
||||
|
||||
class BrainSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(validate_default=False)
|
||||
openai_api_key: str = ""
|
||||
azure_openai_embeddings_url: str = ""
|
||||
supabase_url: str = ""
|
||||
supabase_service_key: str = ""
|
||||
resend_api_key: str = "null"
|
||||
resend_email_address: str = "brain@mail.quivr.app"
|
||||
ollama_api_base_url: str | None = None
|
||||
langfuse_public_key: str | None = None
|
||||
langfuse_secret_key: str | None = None
|
||||
pg_database_url: str
|
||||
pg_database_async_url: str
|
||||
embedding_dim: int = 1536
|
||||
|
||||
|
||||
class ResendSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(validate_default=False)
|
||||
resend_api_key: str = "null"
|
||||
quivr_smtp_server: str = ""
|
||||
quivr_smtp_port: int = 587
|
||||
quivr_smtp_username: str = ""
|
||||
quivr_smtp_password: str = ""
|
||||
|
||||
|
||||
class ParseableSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(validate_default=False)
|
||||
use_parseable: bool = False
|
||||
parseable_url: str | None = None
|
||||
parseable_auth: str | None = None
|
||||
parseable_stream_name: str | None = None
|
||||
|
||||
|
||||
settings = BrainSettings() # type: ignore
|
||||
parseable_settings = ParseableSettings()
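pydantic-settings fills these objects from environment variables matched case-insensitively against the field names, so, as a sketch (the connection strings are placeholders):

import os

os.environ["SUPABASE_URL"] = "https://example.supabase.co"
os.environ["PG_DATABASE_URL"] = "postgresql://user:password@localhost:5432/postgres"
os.environ["PG_DATABASE_ASYNC_URL"] = "postgresql+asyncpg://user:password@localhost:5432/postgres"

local_settings = BrainSettings()  # reads the variables above
print(local_settings.supabase_url, local_settings.embedding_dim)  # ... 1536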
|
@ -1,73 +0,0 @@
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
user_id = Column(String, primary_key=True)
|
||||
email = Column(String)
|
||||
date = Column(DateTime)
|
||||
daily_requests_count = Column(Integer)
|
||||
|
||||
|
||||
class Brain(Base):
|
||||
__tablename__ = "brains"
|
||||
|
||||
brain_id = Column(Integer, primary_key=True)
|
||||
name = Column(String)
|
||||
users = relationship("BrainUser", back_populates="brain")
|
||||
vectors = relationship("BrainVector", back_populates="brain")
|
||||
|
||||
|
||||
class BrainUser(Base):
|
||||
__tablename__ = "brains_users"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
user_id = Column(Integer, ForeignKey("users.user_id"))
|
||||
brain_id = Column(Integer, ForeignKey("brains.brain_id"))
|
||||
rights = Column(String)
|
||||
|
||||
user = relationship("User")
|
||||
brain = relationship("Brain", back_populates="users")
|
||||
|
||||
|
||||
class BrainVector(Base):
|
||||
__tablename__ = "brains_vectors"
|
||||
|
||||
vector_id = Column(String, primary_key=True, default=lambda: str(uuid4()))
|
||||
brain_id = Column(Integer, ForeignKey("brains.brain_id"))
|
||||
file_sha1 = Column(String)
|
||||
|
||||
brain = relationship("Brain", back_populates="vectors")
|
||||
|
||||
|
||||
class BrainSubscriptionInvitation(Base):
|
||||
__tablename__ = "brain_subscription_invitations"
|
||||
|
||||
id = Column(Integer, primary_key=True) # Assuming an integer primary key named 'id'
|
||||
brain_id = Column(String, ForeignKey("brains.brain_id"))
|
||||
email = Column(String, ForeignKey("users.email"))
|
||||
rights = Column(String)
|
||||
|
||||
brain = relationship("Brain")
|
||||
user = relationship("User", foreign_keys=[email])
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
key_id = Column(String, primary_key=True, default=lambda: str(uuid4()))
|
||||
user_id = Column(Integer, ForeignKey("users.user_id"))
|
||||
api_key = Column(String, unique=True)
|
||||
creation_time = Column(DateTime, default=datetime.utcnow)
|
||||
is_active = Column(Boolean, default=True)
|
||||
deleted_time = Column(DateTime, nullable=True)
|
||||
|
||||
user = relationship("User")
|
@ -1,25 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
|
||||
from quivr_api.modules.analytics.entity.analytics import Range
|
||||
from quivr_api.modules.analytics.service.analytics_service import AnalyticsService
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
analytics_service = AnalyticsService()
|
||||
analytics_router = APIRouter()
|
||||
|
||||
|
||||
@analytics_router.get(
|
||||
"/analytics/brains-usages", dependencies=[Depends(AuthBearer())], tags=["Analytics"]
|
||||
)
|
||||
async def get_brains_usages(
|
||||
user: UserIdentity = Depends(get_current_user),
|
||||
brain_id: UUID = Query(None),
|
||||
graph_range: Range = Query(Range.WEEK, alias="graph_range"),
|
||||
):
|
||||
"""
|
||||
Get all user brains usages
|
||||
"""
|
||||
|
||||
return analytics_service.get_brains_usages(user.id, graph_range, brain_id)
|
@ -1,20 +0,0 @@
|
||||
from datetime import date
from enum import IntEnum
from typing import List

from pydantic import BaseModel


class Range(IntEnum):
    WEEK = 7
    MONTH = 30
    QUARTER = 90


class Usage(BaseModel):
    date: date
    usage_count: int


class BrainsUsages(BaseModel):
    usages: List[Usage]
|
@ -1,56 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.analytics.entity.analytics import BrainsUsages, Range, Usage
|
||||
from quivr_api.modules.brain.service.brain_user_service import BrainUserService
|
||||
from quivr_api.modules.dependencies import get_supabase_client
|
||||
|
||||
brain_user_service = BrainUserService()
|
||||
|
||||
|
||||
class Analytics:
|
||||
def __init__(self):
|
||||
supabase_client = get_supabase_client()
|
||||
self.db = supabase_client
|
||||
|
||||
def get_brains_usages(
|
||||
self, user_id: UUID, graph_range: Range, brain_id: Optional[UUID] = None
|
||||
) -> BrainsUsages:
|
||||
user_brains = brain_user_service.get_user_brains(user_id)
|
||||
if brain_id is not None:
|
||||
user_brains = [brain for brain in user_brains if brain.id == brain_id]
|
||||
|
||||
usage_per_day = defaultdict(int)
|
||||
|
||||
brain_ids = [brain.id for brain in user_brains]
|
||||
chat_history = (
|
||||
self.db.from_("chat_history")
|
||||
.select("*")
|
||||
.in_("brain_id", brain_ids)
|
||||
.execute()
|
||||
).data
|
||||
|
||||
for chat in chat_history:
|
||||
message_time = datetime.strptime(
|
||||
chat["message_time"], "%Y-%m-%dT%H:%M:%S.%f"
|
||||
)
|
||||
usage_per_day[message_time.date()] += 1
|
||||
|
||||
start_date = datetime.now().date() - timedelta(days=graph_range)
|
||||
all_dates = [start_date + timedelta(days=i) for i in range(graph_range)]
|
||||
|
||||
for date in all_dates:
|
||||
usage_per_day[date] += 0
|
||||
|
||||
usages = sorted(
|
||||
[
|
||||
Usage(date=date, usage_count=count)
|
||||
for date, count in usage_per_day.items()
|
||||
if start_date <= date <= datetime.now().date()
|
||||
],
|
||||
key=lambda usage: usage.date,
|
||||
)
|
||||
|
||||
return BrainsUsages(usages=usages)
|
@ -1,22 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.analytics.entity.analytics import BrainsUsages, Range
|
||||
|
||||
|
||||
class AnalyticsInterface(ABC):
|
||||
@abstractmethod
|
||||
def get_brains_usages(
|
||||
self,
|
||||
user_id: UUID,
|
||||
graph_range: Range = Range.WEEK,
|
||||
brain_id: Optional[UUID] = None,
|
||||
) -> BrainsUsages:
|
||||
"""
|
||||
Get user brains usage
|
||||
Args:
|
||||
user_id (UUID): The id of the user
|
||||
brain_id (Optional[UUID]): The id of the brain, optional
|
||||
"""
|
||||
pass
|
@ -1,14 +0,0 @@
|
||||
from quivr_api.modules.analytics.repository.analytics import Analytics
from quivr_api.modules.analytics.repository.analytics_interface import (
    AnalyticsInterface,
)


class AnalyticsService:
    repository: AnalyticsInterface

    def __init__(self):
        self.repository = Analytics()

    def get_brains_usages(self, user_id, graph_range, brain_id=None):
        return self.repository.get_brains_usages(user_id, graph_range, brain_id)
|
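For illustration, this service is what the /analytics/brains-usages route earlier in this diff delegates to (the user id is a placeholder, and a configured Supabase backend is assumed):

from uuid import UUID

from quivr_api.modules.analytics.entity.analytics import Range

service = AnalyticsService()
usages = service.get_brains_usages(UUID("39418e3b-0258-4452-af60-7acfcc1263ff"), Range.WEEK)
print(len(usages.usages))  # roughly one entry per day in the selected range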
@ -1 +0,0 @@
|
||||
from .api_key_routes import api_key_router
|
@ -1,92 +0,0 @@
|
||||
from secrets import token_hex
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||
from quivr_api.modules.api_key.dto.outputs import ApiKeyInfo
|
||||
from quivr_api.modules.api_key.entity.api_key import ApiKey
|
||||
from quivr_api.modules.api_key.repository.api_keys import ApiKeys
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
api_key_router = APIRouter()
|
||||
|
||||
api_keys_repository = ApiKeys()
|
||||
|
||||
|
||||
@api_key_router.post(
|
||||
"/api-key",
|
||||
response_model=ApiKey,
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["API Key"],
|
||||
)
|
||||
async def create_api_key(current_user: UserIdentity = Depends(get_current_user)):
|
||||
"""
|
||||
Create new API key for the current user.
|
||||
|
||||
- `current_user`: The current authenticated user.
|
||||
- Returns the newly created API key.
|
||||
|
||||
This endpoint generates a new API key for the current user. The API key is stored in the database and associated with
|
||||
the user. It returns the newly created API key.
|
||||
"""
|
||||
|
||||
new_key_id = uuid4()
|
||||
new_api_key = token_hex(16)
|
||||
|
||||
try:
|
||||
# Attempt to insert new API key into database
|
||||
response = api_keys_repository.create_api_key(
|
||||
new_key_id, new_api_key, current_user.id, "api_key", 30, False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating new API key: {e}")
|
||||
return {"api_key": "Error creating new API key."}
|
||||
logger.info(f"Created new API key for user {current_user.email}.")
|
||||
|
||||
return response # type: ignore
|
||||
|
||||
|
||||
@api_key_router.delete(
|
||||
"/api-key/{key_id}", dependencies=[Depends(AuthBearer())], tags=["API Key"]
|
||||
)
|
||||
async def delete_api_key(
|
||||
key_id: str, current_user: UserIdentity = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Delete (deactivate) an API key for the current user.
|
||||
|
||||
- `key_id`: The ID of the API key to delete.
|
||||
|
||||
This endpoint deactivates and deletes the specified API key associated with the current user. The API key is marked
|
||||
as inactive in the database.
|
||||
|
||||
"""
|
||||
api_keys_repository.delete_api_key(key_id, current_user.id)
|
||||
|
||||
return {"message": "API key deleted."}
|
||||
|
||||
|
||||
@api_key_router.get(
|
||||
"/api-keys",
|
||||
response_model=List[ApiKeyInfo],
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["API Key"],
|
||||
)
|
||||
async def get_api_keys(current_user: UserIdentity = Depends(get_current_user)):
|
||||
"""
|
||||
Get all active API keys for the current user.
|
||||
|
||||
- `current_user`: The current authenticated user.
|
||||
- Returns a list of active API keys with their IDs and creation times.
|
||||
|
||||
This endpoint retrieves all the active API keys associated with the current user. It returns a list of API key objects
|
||||
containing the key ID and creation time for each API key.
|
||||
"""
|
||||
response = api_keys_repository.get_user_api_keys(current_user.id)
|
||||
return response.data
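A client-side sketch: a key created via POST /api-key (which itself requires a JWT) can then be sent as a Bearer token, since AuthBearer above falls back to API-key verification when JWT validation fails. Host, port and the key value are placeholders:

import httpx

api_key = "0123456789abcdef0123456789abcdef"  # placeholder; value returned by POST /api-key
headers = {"Authorization": f"Bearer {api_key}"}

resp = httpx.get("http://localhost:5050/api-keys", headers=headers)
print(resp.json())  # e.g. [{"key_id": "...", "creation_time": "..."}]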
|
@ -1 +0,0 @@
|
||||
from .outputs import ApiKeyInfo
|
@ -1,6 +0,0 @@
|
||||
from pydantic import BaseModel


class ApiKeyInfo(BaseModel):
    key_id: str
    creation_time: str
|
@ -1,11 +0,0 @@
|
||||
from pydantic import BaseModel


class ApiKey(BaseModel):
    api_key: str
    key_id: str
    days: int
    only_chat: bool
    name: str
    creation_time: str
    is_active: bool
|
@ -1,34 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.api_key.entity.api_key import ApiKey
|
||||
|
||||
|
||||
class ApiKeysInterface(ABC):
|
||||
@abstractmethod
|
||||
def create_api_key(
|
||||
self,
|
||||
new_key_id: UUID,
|
||||
new_api_key: str,
|
||||
user_id: UUID,
|
||||
days: int,
|
||||
only_chat: bool,
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_api_key(self, key_id: UUID, user_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_active_api_key(self, api_key: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_user_id_by_api_key(self, api_key: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_user_api_keys(self, user_id: UUID) -> List[ApiKey]:
|
||||
pass
|
@ -1,82 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.api_key.entity.api_key import ApiKey
|
||||
from quivr_api.modules.api_key.repository.api_key_interface import ApiKeysInterface
|
||||
from quivr_api.modules.dependencies import get_supabase_client
|
||||
|
||||
|
||||
class ApiKeys(ApiKeysInterface):
|
||||
def __init__(self):
|
||||
supabase_client = get_supabase_client()
|
||||
self.db = supabase_client # type: ignore
|
||||
|
||||
def create_api_key(
|
||||
self, new_key_id, new_api_key, user_id, name, days=30, only_chat=False
|
||||
) -> Optional[ApiKey]:
|
||||
response = (
|
||||
self.db.table("api_keys")
|
||||
.insert(
|
||||
[
|
||||
{
|
||||
"key_id": str(new_key_id),
|
||||
"user_id": str(user_id),
|
||||
"api_key": str(new_api_key),
|
||||
"name": str(name),
|
||||
"days": int(days),
|
||||
"only_chat": bool(only_chat),
|
||||
"creation_time": datetime.utcnow().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
"is_active": True,
|
||||
}
|
||||
]
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
if len(response.data) == 0:
|
||||
return None
|
||||
return ApiKey(**response.data[0])
|
||||
|
||||
def delete_api_key(self, key_id: str, user_id: UUID):
|
||||
return (
|
||||
self.db.table("api_keys")
|
||||
.update(
|
||||
{
|
||||
"is_active": False,
|
||||
"deleted_time": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
)
|
||||
.match({"key_id": key_id, "user_id": user_id})
|
||||
.execute()
|
||||
)
|
||||
|
||||
def get_active_api_key(self, api_key: str):
|
||||
response = (
|
||||
self.db.table("api_keys")
|
||||
.select("api_key", "creation_time")
|
||||
.filter("api_key", "eq", api_key)
|
||||
.filter("is_active", "eq", str(True))
|
||||
.execute()
|
||||
)
|
||||
return response
|
||||
|
||||
def get_user_id_by_api_key(self, api_key: str):
|
||||
response = (
|
||||
self.db.table("api_keys")
|
||||
.select("user_id")
|
||||
.filter("api_key", "eq", api_key)
|
||||
.execute()
|
||||
)
|
||||
return response
|
||||
|
||||
def get_user_api_keys(self, user_id):
|
||||
response = (
|
||||
self.db.table("api_keys")
|
||||
.select("key_id, creation_time")
|
||||
.filter("user_id", "eq", user_id)
|
||||
.filter("is_active", "eq", True)
|
||||
.execute()
|
||||
)
|
||||
return response.data
|
@ -1,61 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.api_key.repository.api_key_interface import ApiKeysInterface
|
||||
from quivr_api.modules.api_key.repository.api_keys import ApiKeys
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
from quivr_api.modules.user.service.user_service import UserService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
user_service = UserService()
|
||||
|
||||
|
||||
class ApiKeyService:
|
||||
repository: ApiKeysInterface
|
||||
|
||||
def __init__(self):
|
||||
self.repository = ApiKeys()
|
||||
|
||||
async def verify_api_key(
|
||||
self,
|
||||
api_key: str,
|
||||
) -> bool:
|
||||
try:
|
||||
# Use UTC time to avoid timezone issues
|
||||
current_date = datetime.utcnow().date()
|
||||
result = self.repository.get_active_api_key(api_key)
|
||||
|
||||
if result.data is not None and len(result.data) > 0:
|
||||
api_key_creation_date = datetime.strptime(
|
||||
result.data[0]["creation_time"], "%Y-%m-%dT%H:%M:%S"
|
||||
).date()
|
||||
|
||||
if api_key_creation_date.year == current_date.year:
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying API key: {e}")
|
||||
return False
|
||||
|
||||
async def get_user_from_api_key(
|
||||
self,
|
||||
api_key: str,
|
||||
) -> UserIdentity:
|
||||
user_id_data = self.repository.get_user_id_by_api_key(api_key)
|
||||
|
||||
if not user_id_data.data:
|
||||
raise HTTPException(status_code=400, detail="Invalid API key.")
|
||||
|
||||
user_id = user_id_data.data[0]["user_id"]
|
||||
|
||||
# TODO: directly UserService instead
|
||||
email = user_service.get_user_email_by_user_id(user_id)
|
||||
|
||||
if email is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid API key.")
|
||||
|
||||
return UserIdentity(email=email, id=user_id)
|
@ -1,6 +0,0 @@
|
||||
# noqa:
from .assistant_routes import assistant_router

__all__ = [
    "assistant_router",
]
|
@ -1,202 +0,0 @@
|
||||
import io
|
||||
from typing import Annotated, List
|
||||
from uuid import uuid4
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||
import re
|
||||
|
||||
from quivr_api.celery_config import celery
|
||||
from quivr_api.modules.assistant.dto.inputs import FileInput
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
|
||||
from quivr_api.modules.assistant.controller.assistants_definition import (
|
||||
assistants,
|
||||
validate_assistant_input,
|
||||
)
|
||||
from quivr_api.modules.assistant.dto.inputs import CreateTask, InputAssistant
|
||||
from quivr_api.modules.assistant.dto.outputs import AssistantOutput
|
||||
from quivr_api.modules.assistant.entity.assistant_entity import (
|
||||
AssistantSettings,
|
||||
)
|
||||
from quivr_api.modules.assistant.entity.task_entity import TaskMetadata
|
||||
from quivr_api.modules.assistant.services.tasks_service import TasksService
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
from quivr_api.modules.upload.service.upload_file import (
|
||||
upload_file_storage,
|
||||
)
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
assistant_router = APIRouter()
|
||||
|
||||
|
||||
TasksServiceDep = Annotated[TasksService, Depends(get_service(TasksService))]
|
||||
UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
|
||||
|
||||
|
||||
@assistant_router.get(
|
||||
"/assistants", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
|
||||
)
|
||||
async def get_assistants(
|
||||
request: Request,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
) -> List[AssistantOutput]:
|
||||
logger.info("Getting assistants")
|
||||
|
||||
return assistants
|
||||
|
||||
|
||||
@assistant_router.get(
|
||||
"/assistants/tasks", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
|
||||
)
|
||||
async def get_tasks(
|
||||
request: Request,
|
||||
current_user: UserIdentityDep,
|
||||
tasks_service: TasksServiceDep,
|
||||
):
|
||||
logger.info("Getting tasks")
|
||||
return await tasks_service.get_tasks_by_user_id(current_user.id)
|
||||
|
||||
|
||||
@assistant_router.post(
|
||||
"/assistants/task", dependencies=[Depends(AuthBearer())], tags=["Assistant"]
|
||||
)
|
||||
async def create_task(
|
||||
current_user: UserIdentityDep,
|
||||
tasks_service: TasksServiceDep,
|
||||
request: Request,
|
||||
input: str = File(...),
|
||||
files: List[UploadFile] | None = None,
|
||||
):
|
||||
inputs = InputAssistant.model_validate_json(input)
|
||||
|
||||
assistant = next(
|
||||
(assistant for assistant in assistants if assistant.id == inputs.id), None
|
||||
)
|
||||
|
||||
if assistant is None:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
is_valid, validation_errors = validate_assistant_input(inputs, assistant)
|
||||
if not is_valid:
|
||||
for error in validation_errors:
|
||||
print(error)
|
||||
raise HTTPException(status_code=400, detail=error)
|
||||
else:
|
||||
print("Assistant input is valid.")
|
||||
notification_uuid = f"{assistant.name}-{str(uuid4())[:8]}"
|
||||
|
||||
# Process files dynamically
|
||||
for upload_file in files or []:
|
||||
# Sanitize the filename to remove spaces and special characters
|
||||
sanitized_filename = re.sub(r'[^\w\-_\.]', '_', upload_file.filename)
|
||||
upload_file.filename = sanitized_filename
|
||||
|
||||
file_name_path = f"{inputs.id}/{notification_uuid}/{sanitized_filename}"
|
||||
buff_reader = io.BufferedReader(upload_file.file) # type: ignore
|
||||
try:
|
||||
await upload_file_storage(buff_reader, file_name_path)
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception in upload_route {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to upload file to storage. {e}"
|
||||
)
|
||||
logger.info(f"Files are: {files}")
|
||||
|
||||
# Sanitize the filename in input
|
||||
if inputs.inputs.files:
|
||||
inputs.inputs.files = [
|
||||
FileInput(
|
||||
value=re.sub(r'[^\w\-_\.]', '_', file.value),
|
||||
key=file.key
|
||||
)
|
||||
for file in inputs.inputs.files
|
||||
]
|
||||
|
||||
task = CreateTask(
|
||||
assistant_id=inputs.id,
|
||||
assistant_name=assistant.name,
|
||||
pretty_id=notification_uuid,
|
||||
settings=inputs.model_dump(mode="json"),
|
||||
task_metadata=TaskMetadata(
|
||||
input_files=[file.filename for file in files]
|
||||
).model_dump(mode="json")
|
||||
if files
|
||||
else None, # type: ignore
|
||||
)
|
||||
|
||||
task_created = await tasks_service.create_task(task, current_user.id)
|
||||
|
||||
celery.send_task(
|
||||
"process_assistant_task",
|
||||
kwargs={
|
||||
"assistant_id": inputs.id,
|
||||
"notification_uuid": notification_uuid,
|
||||
"task_id": task_created.id,
|
||||
"user_id": str(current_user.id),
|
||||
},
|
||||
)
|
||||
return task_created
|
||||
|
||||
|
||||
@assistant_router.get(
|
||||
"/assistants/task/{task_id}",
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["Assistant"],
|
||||
)
|
||||
async def get_task(
|
||||
request: Request,
|
||||
task_id: str,
|
||||
current_user: UserIdentityDep,
|
||||
tasks_service: TasksServiceDep,
|
||||
):
|
||||
return await tasks_service.get_task_by_id(task_id, current_user.id) # type: ignore
|
||||
|
||||
|
||||
@assistant_router.delete(
|
||||
"/assistants/task/{task_id}",
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["Assistant"],
|
||||
)
|
||||
async def delete_task(
|
||||
request: Request,
|
||||
task_id: int,
|
||||
current_user: UserIdentityDep,
|
||||
tasks_service: TasksServiceDep,
|
||||
):
|
||||
return await tasks_service.delete_task(task_id, current_user.id)
|
||||
|
||||
|
||||
@assistant_router.get(
|
||||
"/assistants/task/{task_id}/download",
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["Assistant"],
|
||||
)
|
||||
async def get_download_link_task(
|
||||
request: Request,
|
||||
task_id: int,
|
||||
current_user: UserIdentityDep,
|
||||
tasks_service: TasksServiceDep,
|
||||
):
|
||||
return await tasks_service.get_download_link_task(task_id, current_user.id)
|
||||
|
||||
|
||||
@assistant_router.get(
|
||||
"/assistants/{assistant_id}/config",
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["Assistant"],
|
||||
response_model=AssistantSettings,
|
||||
summary="Retrieve assistant configuration",
|
||||
description="Get the settings and file requirements for the specified assistant.",
|
||||
)
|
||||
async def get_assistant_config(
|
||||
assistant_id: int,
|
||||
current_user: UserIdentityDep,
|
||||
):
|
||||
assistant = next(
|
||||
(assistant for assistant in assistants if assistant.id == assistant_id), None
|
||||
)
|
||||
if assistant is None:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
return assistant.settings
|
@ -1,251 +0,0 @@
|
||||
from quivr_api.modules.assistant.dto.inputs import InputAssistant
|
||||
from quivr_api.modules.assistant.dto.outputs import (
|
||||
AssistantOutput,
|
||||
ConditionalInput,
|
||||
InputBoolean,
|
||||
InputFile,
|
||||
Inputs,
|
||||
InputSelectText,
|
||||
Pricing,
|
||||
)
|
||||
|
||||
|
||||
def validate_assistant_input(
|
||||
assistant_input: InputAssistant, assistant_output: AssistantOutput
|
||||
):
|
||||
errors = []
|
||||
|
||||
# Validate files
|
||||
if assistant_output.inputs.files:
|
||||
required_files = [
|
||||
file for file in assistant_output.inputs.files if file.required
|
||||
]
|
||||
input_files = {
|
||||
file_input.key for file_input in (assistant_input.inputs.files or [])
|
||||
}
|
||||
for req_file in required_files:
|
||||
if req_file.key not in input_files:
|
||||
errors.append(f"Missing required file input: {req_file.key}")
|
||||
|
||||
# Validate URLs
|
||||
if assistant_output.inputs.urls:
|
||||
required_urls = [url for url in assistant_output.inputs.urls if url.required]
|
||||
input_urls = {
|
||||
url_input.key for url_input in (assistant_input.inputs.urls or [])
|
||||
}
|
||||
for req_url in required_urls:
|
||||
if req_url.key not in input_urls:
|
||||
errors.append(f"Missing required URL input: {req_url.key}")
|
||||
|
||||
# Validate texts
|
||||
if assistant_output.inputs.texts:
|
||||
required_texts = [
|
||||
text for text in assistant_output.inputs.texts if text.required
|
||||
]
|
||||
input_texts = {
|
||||
text_input.key for text_input in (assistant_input.inputs.texts or [])
|
||||
}
|
||||
for req_text in required_texts:
|
||||
if req_text.key not in input_texts:
|
||||
errors.append(f"Missing required text input: {req_text.key}")
|
||||
else:
|
||||
# Validate regex if applicable
|
||||
req_text_val = next(
|
||||
(t for t in assistant_output.inputs.texts if t.key == req_text.key),
|
||||
None,
|
||||
)
|
||||
                if req_text_val and req_text_val.validation_regex:
                    import re

                    input_value = next(
                        (
                            t.value
                            for t in assistant_input.inputs.texts
                            if t.key == req_text.key
                        ),
                        "",
                    )
                    if not re.match(req_text_val.validation_regex, input_value):
                        errors.append(
                            f"Text input '{req_text.key}' does not match the required format."
                        )

    # Validate booleans
    if assistant_output.inputs.booleans:
        required_booleans = [b for b in assistant_output.inputs.booleans if b.required]
        input_booleans = {
            b_input.key for b_input in (assistant_input.inputs.booleans or [])
        }
        for req_bool in required_booleans:
            if req_bool.key not in input_booleans:
                errors.append(f"Missing required boolean input: {req_bool.key}")

    # Validate numbers
    if assistant_output.inputs.numbers:
        required_numbers = [n for n in assistant_output.inputs.numbers if n.required]
        input_numbers = {
            n_input.key for n_input in (assistant_input.inputs.numbers or [])
        }
        for req_number in required_numbers:
            if req_number.key not in input_numbers:
                errors.append(f"Missing required number input: {req_number.key}")
            else:
                # Validate min and max
                input_value = next(
                    (
                        n.value
                        for n in assistant_input.inputs.numbers
                        if n.key == req_number.key
                    ),
                    None,
                )
                if req_number.min is not None and input_value < req_number.min:
                    errors.append(
                        f"Number input '{req_number.key}' is below minimum value."
                    )
                if req_number.max is not None and input_value > req_number.max:
                    errors.append(
                        f"Number input '{req_number.key}' exceeds maximum value."
                    )

    # Validate select_texts
    if assistant_output.inputs.select_texts:
        required_select_texts = [
            st for st in assistant_output.inputs.select_texts if st.required
        ]
        input_select_texts = {
            st_input.key for st_input in (assistant_input.inputs.select_texts or [])
        }
        for req_select in required_select_texts:
            if req_select.key not in input_select_texts:
                errors.append(f"Missing required select text input: {req_select.key}")
            else:
                input_value = next(
                    (
                        st.value
                        for st in assistant_input.inputs.select_texts
                        if st.key == req_select.key
                    ),
                    None,
                )
                if input_value not in req_select.options:
                    errors.append(f"Invalid option for select text '{req_select.key}'.")

    # Validate select_numbers
    if assistant_output.inputs.select_numbers:
        required_select_numbers = [
            sn for sn in assistant_output.inputs.select_numbers if sn.required
        ]
        input_select_numbers = {
            sn_input.key for sn_input in (assistant_input.inputs.select_numbers or [])
        }
        for req_select in required_select_numbers:
            if req_select.key not in input_select_numbers:
                errors.append(f"Missing required select number input: {req_select.key}")
            else:
                input_value = next(
                    (
                        sn.value
                        for sn in assistant_input.inputs.select_numbers
                        if sn.key == req_select.key
                    ),
                    None,
                )
                if input_value not in req_select.options:
                    errors.append(
                        f"Invalid option for select number '{req_select.key}'."
                    )

    # Validate brain input
    if assistant_output.inputs.brain and assistant_output.inputs.brain.required:
        if not assistant_input.inputs.brain or not assistant_input.inputs.brain.value:
            errors.append("Missing required brain input.")

    if errors:
        return False, errors
    else:
        return True, None

assistant1 = AssistantOutput(
    id=1,
    name="Compliance Check",
    description="Allows analyzing the compliance of the information contained in documents against charter or regulatory requirements.",
    pricing=Pricing(),
    tags=["Disabled"],
    input_description="Input description",
    output_description="Output description",
    inputs=Inputs(
        files=[
            InputFile(key="file_1", description="File description"),
            InputFile(key="file_2", description="File description"),
        ],
    ),
    icon_url="https://example.com/icon.png",
)

assistant2 = AssistantOutput(
    id=2,
    name="Consistency Check",
    description="Ensures that the information in one document is replicated identically in another document.",
    pricing=Pricing(),
    tags=[],
    input_description="Input description",
    output_description="Output description",
    icon_url="https://example.com/icon.png",
    inputs=Inputs(
        files=[
            InputFile(key="Document 1", description="File description"),
            InputFile(key="Document 2", description="File description"),
        ],
        select_texts=[
            InputSelectText(
                key="DocumentsType",
                description="Select Documents Type",
                options=[
                    "Cahier des charges VS Etiquettes",
                    "Fiche Dev VS Cahier des charges",
                ],
            ),
        ],
    ),
)

assistant3 = AssistantOutput(
    id=3,
    name="Difference Detection",
    description="Highlights differences between one document and another after modifications.",
    pricing=Pricing(),
    tags=[],
    input_description="Input description",
    output_description="Output description",
    icon_url="https://example.com/icon.png",
    inputs=Inputs(
        files=[
            InputFile(key="Document 1", description="File description"),
            InputFile(key="Document 2", description="File description"),
        ],
        booleans=[
            InputBoolean(
                key="Hard-to-Read Document?", description="Boolean description"
            ),
        ],
        select_texts=[
            InputSelectText(
                key="DocumentsType",
                description="Select Documents Type",
                options=["Etiquettes", "Cahier des charges"],
            ),
        ],
        conditional_inputs=[
            ConditionalInput(
                key="DocumentsType",
                conditional_key="Hard-to-Read Document?",
                condition="equals",
                value="Etiquettes",
            ),
        ],
    ),
)

assistants = [assistant1, assistant2, assistant3]
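For context, a short sketch (not part of the diff) of how the assistant catalogue above can be consumed, for example hiding entries tagged "Disabled" such as assistant1 before returning the list to the frontend. The helper name is illustrative, not an existing function in this PR.

# Hypothetical listing helper; "Disabled" is the tag set on assistant1 above.
def list_enabled_assistants() -> list[AssistantOutput]:
    return [a for a in assistants if "Disabled" not in (a.tags or [])]

# list_enabled_assistants() would return [assistant2, assistant3].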
@ -1,75 +0,0 @@
from typing import Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel, root_validator


class CreateTask(BaseModel):
    pretty_id: str
    assistant_id: int
    assistant_name: str
    settings: dict
    task_metadata: Dict | None = None


class BrainInput(BaseModel):
    value: Optional[UUID] = None

    @root_validator(pre=True)
    def empty_string_to_none(cls, values):
        for field, value in values.items():
            if value == "":
                values[field] = None
        return values


class FileInput(BaseModel):
    key: str
    value: str


class UrlInput(BaseModel):
    key: str
    value: str


class TextInput(BaseModel):
    key: str
    value: str


class InputBoolean(BaseModel):
    key: str
    value: bool


class InputNumber(BaseModel):
    key: str
    value: int


class InputSelectText(BaseModel):
    key: str
    value: str


class InputSelectNumber(BaseModel):
    key: str
    value: int


class Inputs(BaseModel):
    files: Optional[List[FileInput]] = None
    urls: Optional[List[UrlInput]] = None
    texts: Optional[List[TextInput]] = None
    booleans: Optional[List[InputBoolean]] = None
    numbers: Optional[List[InputNumber]] = None
    select_texts: Optional[List[InputSelectText]] = None
    select_numbers: Optional[List[InputSelectNumber]] = None
    brain: Optional[BrainInput] = None


class InputAssistant(BaseModel):
    id: int
    name: str
    inputs: Inputs
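A hedged usage sketch for the request-side models above: building an InputAssistant payload for the "Consistency Check" assistant and passing it to the validation helper from the definitions file earlier in this diff. The validator's exact name is not visible in this hunk, so validate_assistant_input is an assumed name, and the file values are placeholders.

# Hypothetical usage; assistant2 and validate_assistant_input are assumed to be
# importable from the assistant definitions module shown earlier in this diff.
payload = InputAssistant(
    id=2,
    name="Consistency Check",
    inputs=Inputs(
        files=[
            FileInput(key="Document 1", value="spec.pdf"),
            FileInput(key="Document 2", value="labels.pdf"),
        ],
        select_texts=[
            InputSelectText(key="DocumentsType", value="Cahier des charges VS Etiquettes"),
        ],
        brain=BrainInput(value=""),  # empty string is coerced to None by the validator above
    ),
)
is_valid, validation_errors = validate_assistant_input(payload, assistant2)
# is_valid turns False as soon as a required key is missing or a submitted option is not allowed.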
@ -1,105 +0,0 @@
from typing import List, Optional

from pydantic import BaseModel


class BrainInput(BaseModel):
    required: Optional[bool] = True
    description: str
    type: str


class InputFile(BaseModel):
    key: str
    allowed_extensions: Optional[List[str]] = ["pdf"]
    required: Optional[bool] = True
    description: str


class InputUrl(BaseModel):
    key: str
    required: Optional[bool] = True
    description: str


class InputText(BaseModel):
    key: str
    required: Optional[bool] = True
    description: str
    validation_regex: Optional[str] = None


class InputBoolean(BaseModel):
    key: str
    required: Optional[bool] = True
    description: str


class InputNumber(BaseModel):
    key: str
    required: Optional[bool] = True
    description: str
    min: Optional[int] = None
    max: Optional[int] = None
    increment: Optional[int] = None
    default: Optional[int] = None


class InputSelectText(BaseModel):
    key: str
    required: Optional[bool] = True
    description: str
    options: List[str]
    default: Optional[str] = None


class InputSelectNumber(BaseModel):
    key: str
    required: Optional[bool] = True
    description: str
    options: List[int]
    default: Optional[int] = None


class ConditionalInput(BaseModel):
    """
    A conditional input makes one input's visibility depend on the value of another input.

    key: The key of the input that is conditionally shown.
    conditional_key: The key of the input whose value determines whether it is shown.
    """

    key: str
    conditional_key: str
    condition: Optional[str] = (
        None  # e.g. "equals", "contains", "starts_with", "ends_with", "regex", "in", "not_in", "is_empty", "is_not_empty"
    )
    value: Optional[str] = None


class Inputs(BaseModel):
    files: Optional[List[InputFile]] = None
    urls: Optional[List[InputUrl]] = None
    texts: Optional[List[InputText]] = None
    booleans: Optional[List[InputBoolean]] = None
    numbers: Optional[List[InputNumber]] = None
    select_texts: Optional[List[InputSelectText]] = None
    select_numbers: Optional[List[InputSelectNumber]] = None
    brain: Optional[BrainInput] = None
    conditional_inputs: Optional[List[ConditionalInput]] = None


class Pricing(BaseModel):
    cost: int = 20
    description: str = "Credits per use"


class AssistantOutput(BaseModel):
    id: int
    name: str
    description: str
    pricing: Optional[Pricing] = Pricing()
    tags: Optional[List[str]] = []
    input_description: str
    output_description: str
    inputs: Inputs
    icon_url: Optional[str] = None
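The ConditionalInput model above only declares the dependency between two inputs; evaluating the condition is left to the consumer. Below is a minimal sketch under the assumption that only the "equals" condition from the inline comment is handled; show_conditional_input is a hypothetical helper, not part of the diff.

# Hypothetical helper sketch; only the "equals" condition is implemented here.
def show_conditional_input(conditional: ConditionalInput, submitted_values: dict) -> bool:
    """Return True when the input identified by conditional.key should be displayed."""
    current = submitted_values.get(conditional.conditional_key)
    if conditional.condition == "equals":
        return str(current) == str(conditional.value)
    # "contains", "regex", "in", ... would be handled here.
    return True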
@ -1,33 +0,0 @@
from typing import Any, List, Optional

from pydantic import BaseModel


class AssistantFileRequirement(BaseModel):
    name: str
    description: Optional[str] = None
    required: bool = True
    accepted_types: Optional[List[str]] = None  # e.g., ['text/csv', 'application/json']


class AssistantInput(BaseModel):
    name: str
    description: str
    type: str  # e.g., 'boolean', 'uuid', 'string'
    required: bool = True
    regex: Optional[str] = None
    options: Optional[List[Any]] = None  # For predefined choices
    default: Optional[Any] = None


class AssistantSettings(BaseModel):
    inputs: List[AssistantInput]
    files: Optional[List[AssistantFileRequirement]] = None


class Assistant(BaseModel):
    id: int
    name: str
    description: str
    settings: AssistantSettings
    required_files: Optional[List[str]] = None  # List of required file names
@ -1,38 +0,0 @@
from datetime import datetime
from typing import Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel
from sqlmodel import JSON, TIMESTAMP, BigInteger, Column, Field, SQLModel, text


class TaskMetadata(BaseModel):
    input_files: Optional[List[str]] = None


class Task(SQLModel, table=True):
    __tablename__ = "tasks"  # type: ignore

    id: int | None = Field(
        default=None,
        sa_column=Column(
            BigInteger,
            primary_key=True,
            autoincrement=True,
        ),
    )
    assistant_id: int
    assistant_name: str
    pretty_id: str
    user_id: UUID
    status: str = Field(default="pending")
    creation_time: datetime | None = Field(
        default=None,
        sa_column=Column(
            TIMESTAMP(timezone=False),
            server_default=text("CURRENT_TIMESTAMP"),
        ),
    )
    settings: Dict = Field(default_factory=dict, sa_column=Column(JSON))
    answer: str | None = Field(default=None)
    task_metadata: Dict | None = Field(default_factory=dict, sa_column=Column(JSON))
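For reference, a sketch of the CreateTask payload (from the inputs DTO earlier in this diff) that the repository further down turns into a Task row; all concrete values are placeholders.

# Hypothetical payload; task_metadata follows the TaskMetadata shape defined above.
from quivr_api.modules.assistant.dto.inputs import CreateTask

task_request = CreateTask(
    pretty_id="task-0001",
    assistant_id=2,
    assistant_name="Consistency Check",
    settings={"DocumentsType": "Cahier des charges VS Etiquettes"},
    task_metadata={"input_files": ["spec.pdf", "labels.pdf"]},
)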
@ -1,32 +0,0 @@
from abc import ABC, abstractmethod
from typing import List
from uuid import UUID

from quivr_api.modules.assistant.dto.inputs import CreateTask
from quivr_api.modules.assistant.entity.task_entity import Task


class TasksInterface(ABC):
    @abstractmethod
    def create_task(self, task: CreateTask) -> Task:
        pass

    @abstractmethod
    def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
        pass

    @abstractmethod
    def delete_task(self, task_id: UUID, user_id: UUID) -> None:
        pass

    @abstractmethod
    def get_tasks_by_user_id(self, user_id: UUID) -> List[Task]:
        pass

    @abstractmethod
    def update_task(self, task_id: int, task: dict) -> None:
        pass

    @abstractmethod
    def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
        pass
@ -1,86 +0,0 @@
from typing import Sequence
from uuid import UUID

from sqlalchemy import exc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, select

from quivr_api.modules.assistant.dto.inputs import CreateTask
from quivr_api.modules.assistant.entity.task_entity import Task
from quivr_api.modules.dependencies import BaseRepository
from quivr_api.modules.upload.service.generate_file_signed_url import (
    generate_file_signed_url,
)


class TasksRepository(BaseRepository):
    def __init__(self, session: AsyncSession):
        super().__init__(session)

    async def create_task(self, task: CreateTask, user_id: UUID) -> Task:
        try:
            task_to_create = Task(
                assistant_id=task.assistant_id,
                assistant_name=task.assistant_name,
                pretty_id=task.pretty_id,
                user_id=user_id,
                settings=task.settings,
                task_metadata=task.task_metadata,  # type: ignore
            )
            self.session.add(task_to_create)
            await self.session.commit()
        except exc.IntegrityError:
            await self.session.rollback()
            raise Exception("Integrity error while creating task")

        await self.session.refresh(task_to_create)
        return task_to_create

    async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
        query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
        response = await self.session.exec(query)
        return response.one()

    async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]:
        query = (
            select(Task).where(Task.user_id == user_id).order_by(col(Task.id).desc())
        )
        response = await self.session.exec(query)
        return response.all()

    async def delete_task(self, task_id: int, user_id: UUID) -> None:
        query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
        response = await self.session.exec(query)
        task = response.one()
        if task:
            await self.session.delete(task)
            await self.session.commit()
        else:
            raise Exception("Task not found")

    async def update_task(self, task_id: int, task_updates: dict) -> None:
        query = select(Task).where(Task.id == task_id)
        response = await self.session.exec(query)
        task = response.one()
        if task:
            for key, value in task_updates.items():
                setattr(task, key, value)
            await self.session.commit()
        else:
            raise Exception("Task not found")

    async def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
        query = select(Task).where(Task.id == task_id, Task.user_id == user_id)
        response = await self.session.exec(query)
        task = response.one()

        path = f"{task.assistant_id}/{task.pretty_id}/output.pdf"

        try:
            signed_url = generate_file_signed_url(path)
            if signed_url and "signedURL" in signed_url:
                return signed_url["signedURL"]
            else:
                raise Exception("Signed URL not available")
        except Exception:
            return "error"
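A usage sketch for TasksRepository, assuming a sqlmodel AsyncSession factory bound to the Postgres database (the factory itself is not part of this diff) and a placeholder user id.

# Hypothetical wiring; session_factory is assumed to yield a sqlmodel AsyncSession.
from uuid import UUID

from quivr_api.modules.assistant.dto.inputs import CreateTask


async def create_demo_task(session_factory) -> None:
    async with session_factory() as session:
        repo = TasksRepository(session)
        task = await repo.create_task(
            CreateTask(
                pretty_id="task-0001",
                assistant_id=2,
                assistant_name="Consistency Check",
                settings={},
            ),
            user_id=UUID("00000000-0000-0000-0000-000000000001"),
        )
        print(task.id, task.status)  # status defaults to "pending"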
@ -1,32 +0,0 @@
from typing import Sequence
from uuid import UUID

from quivr_api.modules.assistant.dto.inputs import CreateTask
from quivr_api.modules.assistant.entity.task_entity import Task
from quivr_api.modules.assistant.repository.tasks import TasksRepository
from quivr_api.modules.dependencies import BaseService


class TasksService(BaseService[TasksRepository]):
    repository_cls = TasksRepository

    def __init__(self, repository: TasksRepository):
        self.repository = repository

    async def create_task(self, task: CreateTask, user_id: UUID) -> Task:
        return await self.repository.create_task(task, user_id)

    async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
        return await self.repository.get_task_by_id(task_id, user_id)

    async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]:
        return await self.repository.get_tasks_by_user_id(user_id)

    async def delete_task(self, task_id: int, user_id: UUID) -> None:
        return await self.repository.delete_task(task_id, user_id)

    async def update_task(self, task_id: int, task: dict) -> None:
        return await self.repository.update_task(task_id, task)

    async def get_download_link_task(self, task_id: int, user_id: UUID) -> str:
        return await self.repository.get_download_link_task(task_id, user_id)
@ -1,113 +0,0 @@
|
||||
from typing import Any, Generic, Sequence, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import exc
|
||||
from sqlmodel import SQLModel, col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from quivr_api.modules.base_uuid_entity import BaseUUIDModel
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=BaseUUIDModel)
|
||||
CreateSchema = TypeVar("CreateSchema", bound=BaseModel)
|
||||
UpdateSchema = TypeVar("UpdateSchema", bound=BaseModel)
|
||||
T = TypeVar("T", bound=SQLModel)
|
||||
|
||||
|
||||
class BaseCRUDRepository(Generic[ModelType, CreateSchema, UpdateSchema]):
|
||||
def __init__(self, model: type[ModelType], session: AsyncSession):
|
||||
"""
|
||||
Base repository for default CRUD operations
|
||||
"""
|
||||
self.model = model
|
||||
self.session = session
|
||||
|
||||
def get_db(self) -> AsyncSession:
|
||||
return self.session
|
||||
|
||||
async def get_by_id(
|
||||
self, *, id: UUID, db_session: AsyncSession
|
||||
) -> ModelType | None:
|
||||
query = select(self.model).where(self.model.id == id)
|
||||
response = await db_session.exec(query)
|
||||
return response.one()
|
||||
|
||||
async def get_by_ids(
|
||||
self,
|
||||
*,
|
||||
list_ids: list[UUID],
|
||||
db_session: AsyncSession | None = None,
|
||||
) -> Sequence[ModelType] | None:
|
||||
db_session = db_session or self.session
|
||||
response = await db_session.exec(
|
||||
select(self.model).where(col(self.model.id).in_(list_ids))
|
||||
)
|
||||
return response.all()
|
||||
|
||||
async def get_multi(
|
||||
self,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db_session: AsyncSession | None = None,
|
||||
) -> Sequence[ModelType]:
|
||||
db_session = db_session or self.session
|
||||
query = select(self.model).offset(skip).limit(limit)
|
||||
response = await db_session.exec(query)
|
||||
return response.all()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
entity: CreateSchema | ModelType,
|
||||
db_session: AsyncSession | None = None,
|
||||
) -> ModelType:
|
||||
db_session = db_session or self.session
|
||||
db_obj = self.model.model_validate(entity) # type: ignore
|
||||
|
||||
try:
|
||||
db_session.add(db_obj)
|
||||
await db_session.commit()
|
||||
except exc.IntegrityError:
|
||||
await db_session.rollback()
|
||||
# TODO(@aminediro) : for now, build an exception system
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Resource already exists",
|
||||
)
|
||||
await db_session.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
async def update(
|
||||
self,
|
||||
*,
|
||||
obj_current: ModelType,
|
||||
obj_new: UpdateSchema | dict[str, Any] | ModelType,
|
||||
db_session: AsyncSession | None = None,
|
||||
) -> ModelType:
|
||||
db_session = db_session or self.session
|
||||
|
||||
if isinstance(obj_new, dict):
|
||||
update_data = obj_new
|
||||
else:
|
||||
update_data = obj_new.dict(
|
||||
exclude_unset=True
|
||||
) # This tells Pydantic to not include the values that were not sent
|
||||
for field in update_data:
|
||||
setattr(obj_current, field, update_data[field])
|
||||
|
||||
db_session.add(obj_current)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(obj_current)
|
||||
return obj_current
|
||||
|
||||
async def remove(
|
||||
self, *, id: UUID | str, db_session: AsyncSession | None = None
|
||||
) -> ModelType:
|
||||
db_session = db_session or self.session
|
||||
response = await db_session.exec(select(self.model).where(self.model.id == id))
|
||||
obj = response.one()
|
||||
await db_session.delete(obj)
|
||||
await db_session.commit()
|
||||
return obj
|
@ -1,11 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class BaseUUIDModel(SQLModel, table=True):
|
||||
id: UUID = Field(
|
||||
primary_key=True,
|
||||
index=True,
|
||||
nullable=False,
|
||||
)
|
@ -1,5 +0,0 @@
|
||||
from .brain_routes import brain_router
|
||||
|
||||
__all__ = [
|
||||
"brain_router",
|
||||
]
|
@ -1,208 +0,0 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
|
||||
from quivr_api.modules.brain.dto.inputs import (
|
||||
BrainQuestionRequest,
|
||||
BrainUpdatableProperties,
|
||||
CreateBrainProperties,
|
||||
)
|
||||
from quivr_api.modules.brain.entity.brain_entity import (
|
||||
BrainType,
|
||||
MinimalUserBrainEntity,
|
||||
RoleEnum,
|
||||
)
|
||||
from quivr_api.modules.brain.entity.integration_brain import (
|
||||
IntegrationDescriptionEntity,
|
||||
)
|
||||
from quivr_api.modules.brain.service.brain_authorization_service import (
|
||||
has_brain_authorization,
|
||||
)
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
from quivr_api.modules.brain.service.brain_user_service import BrainUserService
|
||||
from quivr_api.modules.brain.service.get_question_context_from_brain import (
|
||||
get_question_context_from_brain,
|
||||
)
|
||||
from quivr_api.modules.brain.service.integration_brain_service import (
|
||||
IntegrationBrainDescriptionService,
|
||||
)
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
from quivr_api.modules.models.service.model_service import ModelService
|
||||
from quivr_api.modules.prompt.service.prompt_service import PromptService
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
from quivr_api.modules.user.service.user_usage import UserUsage
|
||||
from quivr_api.utils.telemetry import maybe_send_telemetry
|
||||
from quivr_api.utils.uuid_generator import generate_uuid_from_string
|
||||
|
||||
logger = get_logger(__name__)
|
||||
brain_router = APIRouter()
|
||||
|
||||
prompt_service = PromptService()
|
||||
brain_service = BrainService()
|
||||
brain_user_service = BrainUserService()
|
||||
integration_brain_description_service = IntegrationBrainDescriptionService()
|
||||
ModelServiceDep = Annotated[ModelService, Depends(get_service(ModelService))]
|
||||
|
||||
|
||||
@brain_router.get(
|
||||
"/brains/integrations/",
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
)
|
||||
async def get_integration_brain_description() -> list[IntegrationDescriptionEntity]:
|
||||
"""Retrieve the integration brain description."""
|
||||
# TODO: Deprecated, remove this endpoint
|
||||
return []
|
||||
|
||||
|
||||
@brain_router.get("/brains/", dependencies=[Depends(AuthBearer())], tags=["Brain"])
|
||||
async def retrieve_all_brains_for_user(
|
||||
model_service: ModelServiceDep,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
"""Retrieve all brains for the current user."""
|
||||
brains = brain_user_service.get_user_brains(current_user.id)
|
||||
models = await model_service.get_models()
|
||||
default_model = await model_service.get_default_model()
|
||||
|
||||
for brain in brains:
|
||||
# find the brain.model in models and set the brain.price to the model.price
|
||||
found = False
|
||||
if brain.model:
|
||||
for model in models:
|
||||
if model.name == brain.model:
|
||||
brain.price = model.price
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
brain.price = default_model.price
|
||||
|
||||
for model in models:
|
||||
brains.append(
|
||||
MinimalUserBrainEntity(
|
||||
id=generate_uuid_from_string(model.name),
|
||||
status="private",
|
||||
brain_type=BrainType.model,
|
||||
name=model.name,
|
||||
rights=RoleEnum.Viewer,
|
||||
model=True,
|
||||
price=model.price,
|
||||
max_input=model.max_input,
|
||||
max_output=model.max_output,
|
||||
display_name=model.display_name,
|
||||
image_url=model.image_url,
|
||||
description=model.description,
|
||||
integration_logo_url="model.integration_id",
|
||||
max_files=0,
|
||||
)
|
||||
)
|
||||
|
||||
return {"brains": brains}
|
||||
|
||||
|
||||
@brain_router.get(
|
||||
"/brains/{brain_id}/",
|
||||
dependencies=[
|
||||
Depends(AuthBearer()),
|
||||
Depends(
|
||||
has_brain_authorization(
|
||||
required_roles=[RoleEnum.Owner, RoleEnum.Editor, RoleEnum.Viewer]
|
||||
)
|
||||
),
|
||||
],
|
||||
tags=["Brain"],
|
||||
)
|
||||
async def retrieve_brain_by_id(
|
||||
brain_id: UUID,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
"""Retrieve details of a specific brain by its ID."""
|
||||
brain_details = brain_service.get_brain_details(brain_id, current_user.id)
|
||||
if brain_details is None:
|
||||
raise HTTPException(status_code=404, detail="Brain details not found")
|
||||
return brain_details
|
||||
|
||||
|
||||
@brain_router.post("/brains/", dependencies=[Depends(AuthBearer())], tags=["Brain"])
|
||||
async def create_new_brain(
|
||||
brain: CreateBrainProperties,
|
||||
request: Request,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new brain for the user."""
|
||||
user_brains = brain_user_service.get_user_brains(current_user.id)
|
||||
user_usage = UserUsage(
|
||||
id=current_user.id,
|
||||
email=current_user.email,
|
||||
)
|
||||
user_settings = user_usage.get_user_settings()
|
||||
|
||||
if len(user_brains) >= user_settings.get("max_brains", 5):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Maximum number of brains reached ({user_settings.get('max_brains', 5)}).",
|
||||
)
|
||||
maybe_send_telemetry("create_brain", {"brain_name": brain.name}, request)
|
||||
new_brain = brain_service.create_brain(
|
||||
brain=brain,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
brain_user_service.create_brain_user(
|
||||
user_id=current_user.id,
|
||||
brain_id=new_brain.brain_id,
|
||||
rights=RoleEnum.Owner,
|
||||
is_default_brain=True,
|
||||
)
|
||||
|
||||
return {"id": new_brain.brain_id, "name": brain.name, "rights": "Owner"}
|
||||
|
||||
|
||||
@brain_router.put(
|
||||
"/brains/{brain_id}/",
|
||||
dependencies=[
|
||||
Depends(AuthBearer()),
|
||||
Depends(has_brain_authorization([RoleEnum.Editor, RoleEnum.Owner])),
|
||||
],
|
||||
tags=["Brain"],
|
||||
)
|
||||
async def update_existing_brain(
|
||||
brain_id: UUID,
|
||||
brain_update_data: BrainUpdatableProperties,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
"""Update an existing brain's configuration."""
|
||||
existing_brain = brain_service.get_brain_details(brain_id, current_user.id)
|
||||
if existing_brain is None:
|
||||
raise HTTPException(status_code=404, detail="Brain not found")
|
||||
|
||||
if brain_update_data.prompt_id is None and existing_brain.prompt_id:
|
||||
prompt = prompt_service.get_prompt_by_id(existing_brain.prompt_id)
|
||||
if prompt and prompt.status == "private":
|
||||
prompt_service.delete_prompt_by_id(existing_brain.prompt_id)
|
||||
|
||||
return {"message": f"Prompt {brain_id} has been updated."}
|
||||
|
||||
elif brain_update_data.status == "private" and existing_brain.status == "public":
|
||||
brain_user_service.delete_brain_users(brain_id)
|
||||
return {"message": f"Brain {brain_id} has been deleted."}
|
||||
|
||||
else:
|
||||
brain_service.update_brain_by_id(brain_id, brain_update_data)
|
||||
|
||||
return {"message": f"Brain {brain_id} has been updated."}
|
||||
|
||||
|
||||
@brain_router.post(
|
||||
"/brains/{brain_id}/documents",
|
||||
dependencies=[Depends(AuthBearer()), Depends(has_brain_authorization())],
|
||||
tags=["Brain"],
|
||||
)
|
||||
async def get_question_context_for_brain(
|
||||
brain_id: UUID, question: BrainQuestionRequest
|
||||
):
|
||||
# TODO: Move this endpoint to AnswerGenerator service
|
||||
"""Retrieve the question context from a specific brain."""
|
||||
context = get_question_context_from_brain(brain_id, question.question)
|
||||
return {"docs": context}
|
@ -1,71 +0,0 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainType
|
||||
from quivr_api.modules.brain.entity.integration_brain import IntegrationType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CreateIntegrationBrain(BaseModel, extra="ignore"):
|
||||
integration_name: str
|
||||
integration_logo_url: str
|
||||
connection_settings: dict
|
||||
integration_type: IntegrationType
|
||||
description: str
|
||||
max_files: int
|
||||
|
||||
|
||||
class BrainIntegrationSettings(BaseModel, extra="ignore"):
|
||||
integration_id: str
|
||||
settings: dict
|
||||
|
||||
|
||||
class BrainIntegrationUpdateSettings(BaseModel, extra="ignore"):
|
||||
settings: dict
|
||||
|
||||
|
||||
class CreateBrainProperties(BaseModel, extra="ignore"):
|
||||
name: Optional[str] = "Default brain"
|
||||
description: str = "This is a description"
|
||||
status: Optional[str] = "private"
|
||||
model: Optional[str] = None
|
||||
temperature: Optional[float] = 0.0
|
||||
max_tokens: Optional[int] = 2000
|
||||
prompt_id: Optional[UUID] = None
|
||||
brain_type: Optional[BrainType] = BrainType.doc
|
||||
integration: Optional[BrainIntegrationSettings] = None
|
||||
snippet_color: Optional[str] = None
|
||||
snippet_emoji: Optional[str] = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
brain_dict = super().dict(*args, **kwargs)
|
||||
if brain_dict.get("prompt_id"):
|
||||
brain_dict["prompt_id"] = str(brain_dict.get("prompt_id"))
|
||||
return brain_dict
|
||||
|
||||
|
||||
class BrainUpdatableProperties(BaseModel, extra="ignore"):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
model: Optional[str] = None
|
||||
max_tokens: Optional[int] = None
|
||||
status: Optional[str] = None
|
||||
prompt_id: Optional[UUID] = None
|
||||
integration: Optional[BrainIntegrationUpdateSettings] = None
|
||||
snippet_color: Optional[str] = None
|
||||
snippet_emoji: Optional[str] = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
brain_dict = super().dict(*args, **kwargs)
|
||||
if brain_dict.get("prompt_id"):
|
||||
brain_dict["prompt_id"] = str(brain_dict.get("prompt_id"))
|
||||
return brain_dict
|
||||
|
||||
|
||||
class BrainQuestionRequest(BaseModel):
|
||||
question: str
|
@ -1,133 +0,0 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from quivr_core.config import BrainConfig
|
||||
from sqlalchemy.dialects.postgresql import ENUM as PGEnum
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import TIMESTAMP, Column, Field, Relationship, SQLModel, text
|
||||
from sqlmodel import UUID as PGUUID
|
||||
|
||||
from quivr_api.modules.brain.entity.integration_brain import (
|
||||
IntegrationDescriptionEntity,
|
||||
IntegrationEntity,
|
||||
)
|
||||
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
|
||||
|
||||
# from sqlmodel import Enum as PGEnum
|
||||
from quivr_api.modules.prompt.entity.prompt import Prompt
|
||||
|
||||
|
||||
class BrainType(str, Enum):
|
||||
doc = "doc"
|
||||
api = "api"
|
||||
composite = "composite"
|
||||
integration = "integration"
|
||||
model = "model"
|
||||
|
||||
|
||||
class Brain(AsyncAttrs, SQLModel, table=True):
|
||||
__tablename__ = "brains" # type: ignore
|
||||
|
||||
brain_id: UUID | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
PGUUID,
|
||||
server_default=text("uuid_generate_v4()"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
name: str
|
||||
description: str
|
||||
status: str | None = None
|
||||
model: str | None = None
|
||||
max_tokens: int | None = None
|
||||
temperature: float | None = None
|
||||
last_update: datetime | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
TIMESTAMP(timezone=False),
|
||||
server_default=text("CURRENT_TIMESTAMP"),
|
||||
),
|
||||
)
|
||||
brain_type: BrainType | None = Field(
|
||||
sa_column=Column(
|
||||
PGEnum(BrainType, name="brain_type_enum", create_type=False),
|
||||
default=BrainType.integration,
|
||||
),
|
||||
)
|
||||
brain_chat_history: List["ChatHistory"] = Relationship( # type: ignore # noqa: F821
|
||||
back_populates="brain", sa_relationship_kwargs={"lazy": "select"}
|
||||
)
|
||||
prompt_id: UUID | None = Field(default=None, foreign_key="prompts.id")
|
||||
prompt: Prompt | None = Relationship( # noqa: F821
|
||||
back_populates="brain", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
knowledges: List[KnowledgeDB] = Relationship(
|
||||
back_populates="brains", link_model=KnowledgeBrain
|
||||
)
|
||||
|
||||
# TODO : add
|
||||
# "meaning" "public"."vector",
|
||||
# "tags" "public"."tags"[]
|
||||
|
||||
|
||||
class BrainEntity(BrainConfig):
|
||||
last_update: datetime | None = None
|
||||
brain_type: BrainType | None = None
|
||||
description: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
meaning: Optional[str] = None
|
||||
openai_api_key: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
model: Optional[str] = None
|
||||
max_tokens: Optional[int] = None
|
||||
status: Optional[str] = None
|
||||
prompt_id: Optional[UUID] = None
|
||||
integration: Optional[IntegrationEntity] = None
|
||||
integration_description: Optional[IntegrationDescriptionEntity] = None
|
||||
snippet_emoji: Optional[str] = None
|
||||
snippet_color: Optional[str] = None
|
||||
|
||||
def dict(self, **kwargs):
|
||||
data = super().dict(
|
||||
**kwargs,
|
||||
)
|
||||
data["id"] = self.id
|
||||
return data
|
||||
|
||||
|
||||
class RoleEnum(str, Enum):
|
||||
Viewer = "Viewer"
|
||||
Editor = "Editor"
|
||||
Owner = "Owner"
|
||||
|
||||
|
||||
class BrainUser(BaseModel):
|
||||
id: UUID
|
||||
user_id: UUID
|
||||
rights: RoleEnum
|
||||
default_brain: bool = False
|
||||
|
||||
|
||||
class MinimalUserBrainEntity(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
brain_model: Optional[str] = None
|
||||
rights: RoleEnum
|
||||
status: str
|
||||
brain_type: BrainType
|
||||
description: str
|
||||
integration_logo_url: str
|
||||
max_files: int
|
||||
price: Optional[int] = None
|
||||
max_input: Optional[int] = None
|
||||
max_output: Optional[int] = None
|
||||
display_name: Optional[str] = None
|
||||
image_url: Optional[str] = None
|
||||
model: bool = False
|
||||
snippet_color: Optional[str] = None
|
||||
snippet_emoji: Optional[str] = None
|
@ -1,46 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class IntegrationType(str, Enum):
|
||||
CUSTOM = "custom"
|
||||
SYNC = "sync"
|
||||
DOC = "doc"
|
||||
|
||||
|
||||
class IntegrationBrainTag(str, Enum):
|
||||
NEW = "new"
|
||||
RECOMMENDED = "recommended"
|
||||
MOST_POPULAR = "most_popular"
|
||||
PREMIUM = "premium"
|
||||
COMING_SOON = "coming_soon"
|
||||
COMMUNITY = "community"
|
||||
DEPRECATED = "deprecated"
|
||||
|
||||
|
||||
class IntegrationDescriptionEntity(BaseModel):
|
||||
id: UUID
|
||||
integration_name: str
|
||||
integration_logo_url: Optional[str] = None
|
||||
connection_settings: Optional[dict] = None
|
||||
integration_type: IntegrationType
|
||||
tags: Optional[list[IntegrationBrainTag]] = []
|
||||
information: Optional[str] = None
|
||||
description: str
|
||||
max_files: int
|
||||
allow_model_change: bool
|
||||
integration_display_name: str
|
||||
onboarding_brain: bool
|
||||
|
||||
|
||||
class IntegrationEntity(BaseModel):
|
||||
id: int
|
||||
user_id: str
|
||||
brain_id: str
|
||||
integration_id: str
|
||||
settings: Optional[dict] = None
|
||||
credentials: Optional[dict] = None
|
||||
last_synced: str
|
@ -1,146 +0,0 @@
|
||||
import json
|
||||
from typing import AsyncIterable
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.chains import ConversationalRetrievalChain, LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BigBrain(KnowledgeBrainQA):
|
||||
"""
|
||||
The BigBrain class integrates advanced conversational retrieval and language model chains
|
||||
to provide comprehensive and context-aware responses to user queries.
|
||||
|
||||
It leverages a combination of document retrieval, question condensation, and document-based
|
||||
question answering to generate responses that are informed by a wide range of knowledge sources.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the BigBrain class with specific configurations.
|
||||
|
||||
Args:
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
"""
|
||||
super().__init__(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_chain(self):
|
||||
"""
|
||||
Constructs and returns the conversational QA chain used by BigBrain.
|
||||
|
||||
Returns:
|
||||
A ConversationalRetrievalChain instance.
|
||||
"""
|
||||
system_template = """Combine these summaries in a way that makes sense and answer the user's question.
|
||||
Use markdown or any other techniques to display the content in a nice and aerated way. Answer in the language of the question.
|
||||
Here are user instructions on how to respond: {custom_personality}
|
||||
______________________
|
||||
{summaries}"""
|
||||
messages = [
|
||||
SystemMessagePromptTemplate.from_template(system_template),
|
||||
HumanMessagePromptTemplate.from_template("{question}"),
|
||||
]
|
||||
CHAT_COMBINE_PROMPT = ChatPromptTemplate.from_messages(messages)
|
||||
|
||||
### Question prompt
|
||||
question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question.
|
||||
Return any relevant text verbatim. Return the answer in the same language as the question. If the answer is not in the text, just say nothing in the same language as the question.
|
||||
{context}
|
||||
Question: {question}
|
||||
Relevant text, if any, else say Nothing:"""
|
||||
QUESTION_PROMPT = PromptTemplate(
|
||||
template=question_prompt_template, input_variables=["context", "question"]
|
||||
)
|
||||
|
||||
### Condense Question Prompt
|
||||
|
||||
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question in exactly the same language as the original question.
|
||||
|
||||
Chat History:
|
||||
{chat_history}
|
||||
Follow Up Input: {question}
|
||||
Standalone question in same language as question:"""
|
||||
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
|
||||
|
||||
api_base = None
|
||||
if self.brain_settings.ollama_api_base_url and self.model.startswith("ollama"):
|
||||
api_base = self.brain_settings.ollama_api_base_url
|
||||
|
||||
llm = ChatLiteLLM(
|
||||
temperature=0,
|
||||
model=self.model,
|
||||
api_base=api_base,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
retriever_doc = self.knowledge_qa.get_retriever()
|
||||
|
||||
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
|
||||
doc_chain = load_qa_chain(
|
||||
llm,
|
||||
chain_type="map_reduce",
|
||||
question_prompt=QUESTION_PROMPT,
|
||||
combine_prompt=CHAT_COMBINE_PROMPT,
|
||||
)
|
||||
|
||||
chain = ConversationalRetrievalChain(
|
||||
retriever=retriever_doc,
|
||||
question_generator=question_generator,
|
||||
combine_docs_chain=doc_chain,
|
||||
)
|
||||
|
||||
return chain
|
||||
|
||||
async def generate_stream(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> AsyncIterable:
|
||||
"""
|
||||
Generates a stream of responses for a given question in real-time.
|
||||
|
||||
Args:
|
||||
chat_id (UUID): The unique identifier for the chat session.
|
||||
question (ChatQuestion): The question object containing the user's query.
|
||||
save_answer (bool): Flag indicating whether to save the answer to the chat history.
|
||||
|
||||
Returns:
|
||||
An asynchronous iterable of response strings.
|
||||
"""
|
||||
conversational_qa_chain = self.get_chain()
|
||||
transformed_history, streamed_chat_history = (
|
||||
self.initialize_streamed_chat_history(chat_id, question)
|
||||
)
|
||||
response_tokens = []
|
||||
|
||||
async for chunk in conversational_qa_chain.astream(
|
||||
{
|
||||
"question": question.question,
|
||||
"chat_history": transformed_history,
|
||||
"custom_personality": (
|
||||
self.prompt_to_use.content if self.prompt_to_use else None
|
||||
),
|
||||
}
|
||||
):
|
||||
if "answer" in chunk:
|
||||
response_tokens.append(chunk["answer"])
|
||||
streamed_chat_history.assistant = chunk["answer"]
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
|
||||
self.save_answer(question, response_tokens, streamed_chat_history, save_answer)
|
@ -1,101 +0,0 @@
|
||||
import json
|
||||
from typing import AsyncIterable
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
|
||||
|
||||
class ClaudeBrain(KnowledgeBrainQA):
|
||||
"""
|
||||
ClaudeBrain integrates with Claude model to provide conversational AI capabilities.
|
||||
It leverages the Claude model for generating responses based on the provided context.
|
||||
|
||||
Attributes:
|
||||
**kwargs: Arbitrary keyword arguments for KnowledgeBrainQA initialization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the ClaudeBrain with the given arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
"""
|
||||
super().__init__(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def calculate_pricing(self):
|
||||
"""
|
||||
Calculates the pricing for using the ClaudeBrain.
|
||||
|
||||
Returns:
|
||||
int: The pricing value.
|
||||
"""
|
||||
return 3
|
||||
|
||||
def get_chain(self):
|
||||
"""
|
||||
Constructs and returns the conversational chain for ClaudeBrain.
|
||||
|
||||
Returns:
|
||||
A conversational chain object.
|
||||
"""
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are Claude powered by Quivr. You are an assistant. {custom_personality}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("human", "{question}"),
|
||||
]
|
||||
)
|
||||
|
||||
chain = prompt | ChatLiteLLM(
|
||||
model="claude-3-haiku-20240307", max_tokens=self.max_tokens
|
||||
)
|
||||
|
||||
return chain
|
||||
|
||||
async def generate_stream(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> AsyncIterable:
|
||||
"""
|
||||
Generates a stream of responses for the given question.
|
||||
|
||||
Args:
|
||||
chat_id (UUID): The chat session ID.
|
||||
question (ChatQuestion): The question object.
|
||||
save_answer (bool): Whether to save the answer.
|
||||
|
||||
Yields:
|
||||
AsyncIterable: A stream of response strings.
|
||||
"""
|
||||
conversational_qa_chain = self.get_chain()
|
||||
transformed_history, streamed_chat_history = (
|
||||
self.initialize_streamed_chat_history(chat_id, question)
|
||||
)
|
||||
response_tokens = []
|
||||
|
||||
async for chunk in conversational_qa_chain.astream(
|
||||
{
|
||||
"question": question.question,
|
||||
"chat_history": transformed_history,
|
||||
"custom_personality": (
|
||||
self.prompt_to_use.content if self.prompt_to_use else None
|
||||
),
|
||||
}
|
||||
):
|
||||
response_tokens.append(chunk.content)
|
||||
streamed_chat_history.assistant = chunk.content
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
|
||||
self.save_answer(question, response_tokens, streamed_chat_history, save_answer)
|
@ -1,283 +0,0 @@
|
||||
import json
|
||||
import operator
|
||||
from typing import Annotated, AsyncIterable, List, Optional, Sequence, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import BaseMessage, ToolMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, StateGraph
|
||||
from langgraph.prebuilt import ToolExecutor, ToolInvocation
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
from quivr_api.modules.tools import (
|
||||
EmailSenderTool,
|
||||
ImageGeneratorTool,
|
||||
URLReaderTool,
|
||||
WebSearchTool,
|
||||
)
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
messages: Annotated[Sequence[BaseMessage], operator.add]
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
chat_service = get_service(ChatService)()
|
||||
|
||||
|
||||
class GPT4Brain(KnowledgeBrainQA):
|
||||
"""
|
||||
GPT4Brain integrates with GPT-4 to provide real-time answers and supports various tools to enhance its capabilities.
|
||||
|
||||
Available Tools:
|
||||
- WebSearchTool: Performs web searches to find relevant information.
|
||||
- ImageGeneratorTool: Generates images based on textual descriptions.
|
||||
- URLReaderTool: Reads and summarizes content from URLs.
|
||||
- EmailSenderTool: Sends emails with specified content.
|
||||
|
||||
Use Cases:
|
||||
- WebSearchTool can be used to find the latest news articles on a specific topic or to gather information from various websites.
|
||||
- ImageGeneratorTool is useful for creating visual content based on textual prompts, such as generating a company logo based on a description.
|
||||
- URLReaderTool can be used to summarize articles or web pages, making it easier to quickly understand the content without reading the entire text.
|
||||
- EmailSenderTool enables automated email sending, such as sending a summary of a meeting's minutes to all participants.
|
||||
"""
|
||||
|
||||
tools: Optional[List[BaseTool]] = None
|
||||
tool_executor: Optional[ToolExecutor] = None
|
||||
function_model: ChatOpenAI = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
**kwargs,
|
||||
)
|
||||
self.tools = [
|
||||
WebSearchTool(),
|
||||
ImageGeneratorTool(),
|
||||
URLReaderTool(),
|
||||
EmailSenderTool(user_email=self.user_email),
|
||||
]
|
||||
self.tool_executor = ToolExecutor(tools=self.tools)
|
||||
|
||||
def calculate_pricing(self):
|
||||
return 3
|
||||
|
||||
def should_continue(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
# Make sure there is a previous message
|
||||
|
||||
if last_message.tool_calls:
|
||||
name = last_message.tool_calls[0]["name"]
|
||||
if name == "image-generator":
|
||||
return "final"
|
||||
# If there is no function call, then we finish
|
||||
if not last_message.tool_calls:
|
||||
return "end"
|
||||
# Otherwise if there is, we check if it's suppose to return direct
|
||||
else:
|
||||
return "continue"
|
||||
|
||||
# Define the function that calls the model
|
||||
def call_model(self, state):
|
||||
messages = state["messages"]
|
||||
response = self.function_model.invoke(messages)
|
||||
# We return a list, because this will get added to the existing list
|
||||
return {"messages": [response]}
|
||||
|
||||
# Define the function to execute tools
|
||||
def call_tool(self, state):
|
||||
messages = state["messages"]
|
||||
# Based on the continue condition
|
||||
# we know the last message involves a function call
|
||||
last_message = messages[-1]
|
||||
# We construct an ToolInvocation from the function_call
|
||||
tool_call = last_message.tool_calls[0]
|
||||
tool_name = tool_call["name"]
|
||||
arguments = tool_call["args"]
|
||||
|
||||
action = ToolInvocation(
|
||||
tool=tool_call["name"],
|
||||
tool_input=tool_call["args"],
|
||||
)
|
||||
# We call the tool_executor and get back a response
|
||||
response = self.tool_executor.invoke(action)
|
||||
# We use the response to create a FunctionMessage
|
||||
function_message = ToolMessage(
|
||||
content=str(response), name=action.tool, tool_call_id=tool_call["id"]
|
||||
)
|
||||
# We return a list, because this will get added to the existing list
|
||||
return {"messages": [function_message]}
|
||||
|
||||
def create_graph(self):
|
||||
# Define a new graph
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
# Define the two nodes we will cycle between
|
||||
workflow.add_node("agent", self.call_model)
|
||||
workflow.add_node("action", self.call_tool)
|
||||
workflow.add_node("final", self.call_tool)
|
||||
|
||||
# Set the entrypoint as `agent`
|
||||
# This means that this node is the first one called
|
||||
workflow.set_entry_point("agent")
|
||||
|
||||
# We now add a conditional edge
|
||||
workflow.add_conditional_edges(
|
||||
# First, we define the start node. We use `agent`.
|
||||
# This means these are the edges taken after the `agent` node is called.
|
||||
"agent",
|
||||
# Next, we pass in the function that will determine which node is called next.
|
||||
self.should_continue,
|
||||
# Finally we pass in a mapping.
|
||||
# The keys are strings, and the values are other nodes.
|
||||
# END is a special node marking that the graph should finish.
|
||||
# What will happen is we will call `should_continue`, and then the output of that
|
||||
# will be matched against the keys in this mapping.
|
||||
# Based on which one it matches, that node will then be called.
|
||||
{
|
||||
# If `tools`, then we call the tool node.
|
||||
"continue": "action",
|
||||
# Final call
|
||||
"final": "final",
|
||||
# Otherwise we finish.
|
||||
"end": END,
|
||||
},
|
||||
)
|
||||
|
||||
# We now add a normal edge from `tools` to `agent`.
|
||||
# This means that after `tools` is called, `agent` node is called next.
|
||||
workflow.add_edge("action", "agent")
|
||||
workflow.add_edge("final", END)
|
||||
|
||||
# Finally, we compile it!
|
||||
# This compiles it into a LangChain Runnable,
|
||||
# meaning you can use it as you would any other runnable
|
||||
app = workflow.compile()
|
||||
return app
|
||||
|
||||
def get_chain(self):
|
||||
self.function_model = ChatOpenAI(model="gpt-4o", temperature=0, streaming=True)
|
||||
|
||||
self.function_model = self.function_model.bind_tools(self.tools)
|
||||
|
||||
graph = self.create_graph()
|
||||
|
||||
return graph
|
||||
|
||||
async def generate_stream(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> AsyncIterable:
|
||||
conversational_qa_chain = self.get_chain()
|
||||
transformed_history, streamed_chat_history = (
|
||||
self.initialize_streamed_chat_history(chat_id, question)
|
||||
)
|
||||
filtered_history = self.filter_history(transformed_history, 40, 2000)
|
||||
response_tokens = []
|
||||
config = {"metadata": {"conversation_id": str(chat_id)}}
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are GPT-4 powered by Quivr. You are an assistant. {custom_personality}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("human", "{question}"),
|
||||
]
|
||||
)
|
||||
prompt_formated = prompt.format_messages(
|
||||
chat_history=filtered_history,
|
||||
question=question.question,
|
||||
custom_personality=(
|
||||
self.prompt_to_use.content if self.prompt_to_use else None
|
||||
),
|
||||
)
|
||||
|
||||
async for event in conversational_qa_chain.astream_events(
|
||||
{"messages": prompt_formated},
|
||||
config=config,
|
||||
version="v1",
|
||||
):
|
||||
kind = event["event"]
|
||||
if kind == "on_chat_model_stream":
|
||||
content = event["data"]["chunk"].content
|
||||
if content:
|
||||
# Empty content in the context of OpenAI or Anthropic usually means
|
||||
# that the model is asking for a tool to be invoked.
|
||||
# So we only print non-empty content
|
||||
response_tokens.append(content)
|
||||
streamed_chat_history.assistant = content
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
elif kind == "on_tool_start":
|
||||
print("--")
|
||||
print(
|
||||
f"Starting tool: {event['name']} with inputs: {event['data'].get('input')}"
|
||||
)
|
||||
elif kind == "on_tool_end":
|
||||
print(f"Done tool: {event['name']}")
|
||||
print(f"Tool output was: {event['data'].get('output')}")
|
||||
print("--")
|
||||
elif kind == "on_chain_end":
|
||||
output = event["data"]["output"]
|
||||
final_output = [item for item in output if "final" in item]
|
||||
if final_output:
|
||||
if (
|
||||
final_output[0]["final"]["messages"][0].name
|
||||
== "image-generator"
|
||||
):
|
||||
final_message = final_output[0]["final"]["messages"][0].content
|
||||
response_tokens.append(final_message)
|
||||
streamed_chat_history.assistant = final_message
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
|
||||
self.save_answer(question, response_tokens, streamed_chat_history, save_answer)
|
||||
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> GetChatHistoryOutput:
|
||||
conversational_qa_chain = self.get_chain()
|
||||
transformed_history, _ = self.initialize_streamed_chat_history(
|
||||
chat_id, question
|
||||
)
|
||||
filtered_history = self.filter_history(transformed_history, 40, 2000)
|
||||
config = {"metadata": {"conversation_id": str(chat_id)}}
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are GPT-4 powered by Quivr. You are an assistant. {custom_personality}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("human", "{question}"),
|
||||
]
|
||||
)
|
||||
prompt_formated = prompt.format_messages(
|
||||
chat_history=filtered_history,
|
||||
question=question.question,
|
||||
custom_personality=(
|
||||
self.prompt_to_use.content if self.prompt_to_use else None
|
||||
),
|
||||
)
|
||||
model_response = conversational_qa_chain.invoke(
|
||||
{"messages": prompt_formated},
|
||||
config=config,
|
||||
)
|
||||
|
||||
answer = model_response["messages"][-1].content
|
||||
|
||||
return self.save_non_streaming_answer(
|
||||
chat_id=chat_id, question=question, answer=answer, metadata={}
|
||||
)
|