mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-15 01:21:48 +03:00
Feat/update brain fields (#756)
* 🗃️ update and add fields in brains table * ✨ update endpoints for more brain attribute * ✨ new set as default brain endpoint * 🔥 remove update brain with file * ✏️ fix wrong auto imports * 🐛 fix max tokens for model in front * 🚑 post instead of put to set default brain * 🚑 update brain creation endpoint with new fields
This commit is contained in:
parent
046cc3fc1d
commit
e05f25b025
@ -2,11 +2,10 @@ from typing import Any, List, Optional
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from logger import get_logger
|
from logger import get_logger
|
||||||
from pydantic import BaseModel
|
|
||||||
from utils.vectors import get_unique_files_from_vector_ids
|
|
||||||
|
|
||||||
from models.settings import BrainRateLimiting, CommonsDep, common_dependencies
|
from models.settings import BrainRateLimiting, CommonsDep, common_dependencies
|
||||||
from models.users import User
|
from models.users import User
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from utils.vectors import get_unique_files_from_vector_ids
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@ -14,10 +13,12 @@ logger = get_logger(__name__)
|
|||||||
class Brain(BaseModel):
|
class Brain(BaseModel):
|
||||||
id: Optional[UUID] = None
|
id: Optional[UUID] = None
|
||||||
name: Optional[str] = "Default brain"
|
name: Optional[str] = "Default brain"
|
||||||
status: Optional[str] = "public"
|
description: Optional[str] = "This is a description"
|
||||||
|
status: Optional[str] = "private"
|
||||||
model: Optional[str] = "gpt-3.5-turbo-0613"
|
model: Optional[str] = "gpt-3.5-turbo-0613"
|
||||||
temperature: Optional[float] = 0.0
|
temperature: Optional[float] = 0.0
|
||||||
max_tokens: Optional[int] = 256
|
max_tokens: Optional[int] = 256
|
||||||
|
openai_api_key: Optional[str] = None
|
||||||
files: List[Any] = []
|
files: List[Any] = []
|
||||||
max_brain_size = BrainRateLimiting().max_brain_size
|
max_brain_size = BrainRateLimiting().max_brain_size
|
||||||
|
|
||||||
@ -150,7 +151,20 @@ class Brain(BaseModel):
|
|||||||
def create_brain(self):
|
def create_brain(self):
|
||||||
commons = common_dependencies()
|
commons = common_dependencies()
|
||||||
response = (
|
response = (
|
||||||
commons["supabase"].table("brains").insert({"name": self.name}).execute()
|
commons["supabase"]
|
||||||
|
.table("brains")
|
||||||
|
.insert(
|
||||||
|
{
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"model": self.model,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"openai_api_key": self.openai_api_key,
|
||||||
|
"status": self.status,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.id = response.data[0]["brain_id"]
|
self.id = response.data[0]["brain_id"]
|
||||||
@ -174,6 +188,18 @@ class Brain(BaseModel):
|
|||||||
|
|
||||||
return response.data
|
return response.data
|
||||||
|
|
||||||
|
def set_as_default_brain_for_user(self, user: User):
|
||||||
|
old_default_brain = get_default_user_brain(user)
|
||||||
|
|
||||||
|
if old_default_brain is not None:
|
||||||
|
self.commons["supabase"].table("brains_users").update(
|
||||||
|
{"default_brain": False}
|
||||||
|
).match({"brain_id": old_default_brain["id"], "user_id": user.id}).execute()
|
||||||
|
|
||||||
|
self.commons["supabase"].table("brains_users").update(
|
||||||
|
{"default_brain": True}
|
||||||
|
).match({"brain_id": self.id, "user_id": user.id}).execute()
|
||||||
|
|
||||||
def create_brain_vector(self, vector_id, file_sha1):
|
def create_brain_vector(self, vector_id, file_sha1):
|
||||||
response = (
|
response = (
|
||||||
self.commons["supabase"]
|
self.commons["supabase"]
|
||||||
@ -201,15 +227,17 @@ class Brain(BaseModel):
|
|||||||
return vectorsResponse.data
|
return vectorsResponse.data
|
||||||
|
|
||||||
def update_brain_fields(self):
|
def update_brain_fields(self):
|
||||||
self.commons["supabase"].table("brains").update({"name": self.name}).match(
|
self.commons["supabase"].table("brains").update(
|
||||||
{"brain_id": self.id}
|
{
|
||||||
).execute()
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
def update_brain_with_file(self, file_sha1: str):
|
"model": self.model,
|
||||||
# not used
|
"temperature": self.temperature,
|
||||||
vector_ids = self.get_vector_ids_from_file_sha1(file_sha1)
|
"max_tokens": self.max_tokens,
|
||||||
for vector_id in vector_ids:
|
"openai_api_key": self.openai_api_key,
|
||||||
self.create_brain_vector(vector_id, file_sha1)
|
"status": self.status,
|
||||||
|
}
|
||||||
|
).match({"brain_id": self.id}).execute()
|
||||||
|
|
||||||
def get_unique_brain_files(self):
|
def get_unique_brain_files(self):
|
||||||
"""
|
"""
|
||||||
|
@ -8,10 +8,9 @@ from models.brains import (
|
|||||||
get_default_user_brain,
|
get_default_user_brain,
|
||||||
get_default_user_brain_or_create_new,
|
get_default_user_brain_or_create_new,
|
||||||
)
|
)
|
||||||
from models.settings import BrainRateLimiting, common_dependencies
|
from models.settings import BrainRateLimiting
|
||||||
from models.users import User
|
from models.users import User
|
||||||
|
from routes.authorizations.brain_authorization import RoleEnum, has_brain_authorization
|
||||||
from routes.authorizations.brain_authorization import has_brain_authorization
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@ -102,8 +101,6 @@ async def create_brain_endpoint(
|
|||||||
In the brains table & in the brains_users table and put the creator user as 'Owner'
|
In the brains table & in the brains_users table and put the creator user as 'Owner'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
brain = Brain(name=brain.name) # pyright: ignore reportPrivateUsage=none
|
|
||||||
|
|
||||||
user_brains = brain.get_user_brains(current_user.id)
|
user_brains = brain.get_user_brains(current_user.id)
|
||||||
max_brain_per_user = BrainRateLimiting().max_brain_per_user
|
max_brain_per_user = BrainRateLimiting().max_brain_per_user
|
||||||
|
|
||||||
@ -142,7 +139,7 @@ async def create_brain_endpoint(
|
|||||||
Depends(
|
Depends(
|
||||||
AuthBearer(),
|
AuthBearer(),
|
||||||
),
|
),
|
||||||
Depends(has_brain_authorization()),
|
Depends(has_brain_authorization([RoleEnum.Editor, RoleEnum.Owner])),
|
||||||
],
|
],
|
||||||
tags=["Brain"],
|
tags=["Brain"],
|
||||||
)
|
)
|
||||||
@ -151,22 +148,35 @@ async def update_brain_endpoint(
|
|||||||
input_brain: Brain,
|
input_brain: Brain,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Update an existing brain with new brain parameters/files.
|
Update an existing brain with new brain configuration
|
||||||
If the file is contained in Add file to brain :
|
"""
|
||||||
if given a fileName/ file sha1 / -> add all the vector Ids to the brains_vectors
|
input_brain.id = brain_id
|
||||||
Modify other brain fields:
|
print("brain", input_brain)
|
||||||
name, status, model, max_tokens, temperature
|
|
||||||
Return modified brain ? No need -> do an optimistic update
|
input_brain.update_brain_fields()
|
||||||
|
return {"message": f"Brain {brain_id} has been updated."}
|
||||||
|
|
||||||
|
|
||||||
|
# set as default brain
|
||||||
|
@brain_router.post(
|
||||||
|
"/brains/{brain_id}/default",
|
||||||
|
dependencies=[
|
||||||
|
Depends(
|
||||||
|
AuthBearer(),
|
||||||
|
),
|
||||||
|
Depends(has_brain_authorization()),
|
||||||
|
],
|
||||||
|
tags=["Brain"],
|
||||||
|
)
|
||||||
|
async def set_as_default_brain_endpoint(
|
||||||
|
brain_id: UUID,
|
||||||
|
user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Set a brain as default for the current user.
|
||||||
"""
|
"""
|
||||||
commons = common_dependencies()
|
|
||||||
brain = Brain(id=brain_id)
|
brain = Brain(id=brain_id)
|
||||||
|
|
||||||
# Add new file to brain , il file_sha1 already exists in brains_vectors -> out (not now)
|
brain.set_as_default_brain_for_user(user)
|
||||||
if brain.file_sha1: # pyright: ignore reportPrivateUsage=none
|
|
||||||
# add all the vector Ids to the brains_vectors with the given brain.brain_id
|
|
||||||
brain.update_brain_with_file(
|
|
||||||
file_sha1=input_brain.file_sha1 # pyright: ignore reportPrivateUsage=none
|
|
||||||
)
|
|
||||||
|
|
||||||
brain.update_brain_fields(commons, brain) # pyright: ignore reportPrivateUsage=none
|
return {"message": f"Brain {brain_id} has been set as default brain."}
|
||||||
return {"message": f"Brain {brain_id} has been updated."}
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
|
|
||||||
|
from models.brains import get_default_user_brain
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_default_brain(client, api_key):
|
def test_retrieve_default_brain(client, api_key):
|
||||||
# Making a GET request to the /brains/default/ endpoint
|
# Making a GET request to the /brains/default/ endpoint
|
||||||
@ -161,3 +163,55 @@ def test_delete_all_brains_and_get_default_brain(client, api_key):
|
|||||||
# Assert that the response status code is 200 (HTTP OK)
|
# Assert that the response status code is 200 (HTTP OK)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["name"] == "Default brain"
|
assert response.json()["name"] == "Default brain"
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_as_default_brain_endpoint(client, api_key):
|
||||||
|
random_brain_name = "".join(
|
||||||
|
random.choices(string.ascii_letters + string.digits, k=10)
|
||||||
|
)
|
||||||
|
# Set up the request payload
|
||||||
|
payload = {
|
||||||
|
"name": random_brain_name,
|
||||||
|
"status": "public",
|
||||||
|
"model": "gpt-3.5-turbo-0613",
|
||||||
|
"temperature": 0,
|
||||||
|
"max_tokens": 256,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Making a POST request to the /brains/ endpoint
|
||||||
|
response = client.post(
|
||||||
|
"/brains/",
|
||||||
|
json=payload,
|
||||||
|
headers={"Authorization": "Bearer " + api_key},
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
brain_id = response_data["id"]
|
||||||
|
|
||||||
|
# Make a POST request to set the brain as default for the user
|
||||||
|
response = client.post(
|
||||||
|
f"/brains/{brain_id}/default",
|
||||||
|
headers={"Authorization": "Bearer " + api_key},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that the response status code is 200 (HTTP OK)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Assert the response message
|
||||||
|
assert response.json() == {
|
||||||
|
"message": f"Brain {brain_id} has been set as default brain."
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if the brain is now the default for the user
|
||||||
|
|
||||||
|
# Send a request to get user information
|
||||||
|
response = client.get("/user", headers={"Authorization": "Bearer " + api_key})
|
||||||
|
# Assert that the response contains the expected fields
|
||||||
|
user_info = response.json()
|
||||||
|
user_id = user_info["id"]
|
||||||
|
|
||||||
|
default_brain = get_default_user_brain(user_id)
|
||||||
|
assert default_brain is not None
|
||||||
|
assert default_brain["id"] == brain_id
|
||||||
|
assert default_brain["default_brain"] is True
|
||||||
|
@ -198,7 +198,7 @@ describe("useBrainApi", () => {
|
|||||||
} = renderHook(() => useBrainApi());
|
} = renderHook(() => useBrainApi());
|
||||||
const brainId = "123";
|
const brainId = "123";
|
||||||
await setAsDefaultBrain(brainId);
|
await setAsDefaultBrain(brainId);
|
||||||
expect(axiosPutMock).toHaveBeenCalledTimes(1);
|
expect(axiosPostMock).toHaveBeenCalledTimes(1);
|
||||||
expect(axiosPutMock).toHaveBeenCalledWith(`/brains/${brainId}/default`);
|
expect(axiosPostMock).toHaveBeenCalledWith(`/brains/${brainId}/default`);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
@ -131,5 +131,5 @@ export const setAsDefaultBrain = async (
|
|||||||
brainId: string,
|
brainId: string,
|
||||||
axiosInstance: AxiosInstance
|
axiosInstance: AxiosInstance
|
||||||
): Promise<void> => {
|
): Promise<void> => {
|
||||||
await axiosInstance.put(`/brains/${brainId}/default`);
|
await axiosInstance.post(`/brains/${brainId}/default`);
|
||||||
};
|
};
|
||||||
|
@ -107,9 +107,8 @@ export const AddBrainModal = (): JSX.Element => {
|
|||||||
</label>
|
</label>
|
||||||
<input
|
<input
|
||||||
type="range"
|
type="range"
|
||||||
min="256"
|
min="10"
|
||||||
max={defineMaxTokens(model)}
|
max={defineMaxTokens(model)}
|
||||||
step="32"
|
|
||||||
value={maxTokens}
|
value={maxTokens}
|
||||||
{...register("maxTokens")}
|
{...register("maxTokens")}
|
||||||
/>
|
/>
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
/* eslint-disable max-lines */
|
/* eslint-disable max-lines */
|
||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
import { FormEvent, useState } from "react";
|
import { FormEvent, useEffect, useState } from "react";
|
||||||
import { useForm } from "react-hook-form";
|
import { useForm } from "react-hook-form";
|
||||||
|
|
||||||
import { useBrainApi } from "@/lib/api/brain/useBrainApi";
|
import { useBrainApi } from "@/lib/api/brain/useBrainApi";
|
||||||
import { useBrainConfig } from "@/lib/context/BrainConfigProvider";
|
import { useBrainConfig } from "@/lib/context/BrainConfigProvider";
|
||||||
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
|
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
|
||||||
|
import { defineMaxTokens } from "@/lib/helpers/defineMexTokens";
|
||||||
import { useToast } from "@/lib/hooks";
|
import { useToast } from "@/lib/hooks";
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
|
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
|
||||||
@ -23,7 +24,7 @@ export const useAddBrainModal = () => {
|
|||||||
setDefault: false,
|
setDefault: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
const { register, getValues, reset, watch } = useForm({
|
const { register, getValues, reset, watch, setValue } = useForm({
|
||||||
defaultValues,
|
defaultValues,
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -32,20 +33,14 @@ export const useAddBrainModal = () => {
|
|||||||
const temperature = watch("temperature");
|
const temperature = watch("temperature");
|
||||||
const maxTokens = watch("maxTokens");
|
const maxTokens = watch("maxTokens");
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setValue("maxTokens", Math.min(maxTokens, defineMaxTokens(model)));
|
||||||
|
}, [maxTokens, model, setValue]);
|
||||||
|
|
||||||
const handleSubmit = async (e: FormEvent) => {
|
const handleSubmit = async (e: FormEvent) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
const { name, description, setDefault } = getValues();
|
const { name, description, setDefault } = getValues();
|
||||||
|
|
||||||
console.log({
|
|
||||||
name,
|
|
||||||
description,
|
|
||||||
maxTokens,
|
|
||||||
model,
|
|
||||||
setDefault,
|
|
||||||
openAiKey,
|
|
||||||
temperature,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (name.trim() === "" || isPending) {
|
if (name.trim() === "" || isPending) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
31
scripts/202307241530031_add_fields_to_brain.sql
Normal file
31
scripts/202307241530031_add_fields_to_brain.sql
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
BEGIN;
|
||||||
|
|
||||||
|
-- Change max_tokens type to INT
|
||||||
|
ALTER TABLE brains ALTER COLUMN max_tokens TYPE INT USING max_tokens::INT;
|
||||||
|
|
||||||
|
-- Add or rename the api_key column to openai_api_key
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
BEGIN
|
||||||
|
-- Check if the api_key column exists
|
||||||
|
IF EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'brains' AND column_name = 'api_key') THEN
|
||||||
|
-- Rename the api_key column to openai_api_key
|
||||||
|
ALTER TABLE brains RENAME COLUMN api_key TO openai_api_key;
|
||||||
|
ELSE
|
||||||
|
-- Create the openai_api_key column if it doesn't exist
|
||||||
|
ALTER TABLE brains ADD COLUMN openai_api_key TEXT;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
-- Add description column
|
||||||
|
ALTER TABLE brains ADD COLUMN description TEXT;
|
||||||
|
|
||||||
|
-- Update migrations table
|
||||||
|
INSERT INTO migrations (name)
|
||||||
|
SELECT '202307241530031_add_fields_to_brain'
|
||||||
|
WHERE NOT EXISTS (
|
||||||
|
SELECT 1 FROM migrations WHERE name = '202307241530031_add_fields_to_brain'
|
||||||
|
);
|
||||||
|
|
||||||
|
COMMIT;
|
@ -126,14 +126,15 @@ CREATE TABLE IF NOT EXISTS api_keys(
|
|||||||
is_active BOOLEAN DEFAULT true
|
is_active BOOLEAN DEFAULT true
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Create brains table
|
|
||||||
CREATE TABLE IF NOT EXISTS brains (
|
CREATE TABLE IF NOT EXISTS brains (
|
||||||
brain_id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
|
brain_id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
|
||||||
name TEXT,
|
name TEXT NOT NULL,
|
||||||
status TEXT,
|
status TEXT,
|
||||||
|
description TEXT,
|
||||||
model TEXT,
|
model TEXT,
|
||||||
max_tokens TEXT,
|
max_tokens INT,
|
||||||
temperature FLOAT
|
temperature FLOAT,
|
||||||
|
openai_api_key TEXT
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Create brains X users table
|
-- Create brains X users table
|
||||||
@ -193,7 +194,7 @@ CREATE TABLE IF NOT EXISTS migrations (
|
|||||||
);
|
);
|
||||||
|
|
||||||
INSERT INTO migrations (name)
|
INSERT INTO migrations (name)
|
||||||
SELECT '20230717173000_add_get_user_id_by_user_email'
|
SELECT '202307241530031_add_fields_to_brain'
|
||||||
WHERE NOT EXISTS (
|
WHERE NOT EXISTS (
|
||||||
SELECT 1 FROM migrations WHERE name = '20230717173000_add_get_user_id_by_user_email'
|
SELECT 1 FROM migrations WHERE name = '202307241530031_add_fields_to_brain'
|
||||||
);
|
);
|
||||||
|
Loading…
Reference in New Issue
Block a user