feat: add user level open ai key management (#805)

* feat: add user user identity table

* feat: add user openai api key input

* feat: add encryption missing message

* chore: log more details about 422 errors

* docs(API): update api creation path

* feat: use user openai key if defined
This commit is contained in:
Mamadou DICKO 2023-08-01 09:24:57 +02:00 committed by GitHub
parent b72139af60
commit 7532b558c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 452 additions and 11 deletions

View File

@ -2,7 +2,8 @@ import os
import pypandoc import pypandoc
import sentry_sdk import sentry_sdk
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from logger import get_logger from logger import get_logger
from middlewares.cors import add_cors_middleware from middlewares.cors import add_cors_middleware
@ -53,3 +54,24 @@ async def http_exception_handler(_, exc):
status_code=exc.status_code, status_code=exc.status_code,
content={"detail": exc.detail}, content={"detail": exc.detail},
) )
# log more details about validation errors (422)
def handle_request_validation_error(app: FastAPI):
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
logger.error(request, exc_str)
content = {
"status_code": status.HTTP_422_UNPROCESSABLE_ENTITY,
"message": exc_str,
"data": None,
}
return JSONResponse(
content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
)
handle_request_validation_error(app)

View File

@ -0,0 +1,9 @@
from typing import Optional
from uuid import UUID
from pydantic import BaseModel
class UserIdentity(BaseModel):
user_id: UUID
openai_api_key: Optional[str] = None

View File

@ -2,19 +2,20 @@ from typing import Optional
from uuid import UUID from uuid import UUID
from logger import get_logger from logger import get_logger
from models.settings import common_dependencies
from pydantic import BaseModel from pydantic import BaseModel
from models.settings import common_dependencies
logger = get_logger(__name__) logger = get_logger(__name__)
# [TODO] Rename the user table and its references to 'user_usage'
class User(BaseModel): class User(BaseModel):
id: UUID id: UUID
email: Optional[str] email: Optional[str]
user_openai_api_key: Optional[str] = None user_openai_api_key: Optional[str] = None
requests_count: int = 0 requests_count: int = 0
# [TODO] Rename the user table and its references to 'user_usage'
def create_user(self, date): def create_user(self, date):
""" """
Create a new user entry in the database Create a new user entry in the database

View File

@ -0,0 +1,13 @@
from models.settings import common_dependencies
from models.user_identity import UserIdentity
def create_user_identity(user_identity: UserIdentity) -> UserIdentity:
commons = common_dependencies()
user_identity_dict = user_identity.dict()
user_identity_dict["user_id"] = str(user_identity.user_id)
response = (
commons["supabase"].from_("user_identity").insert(user_identity_dict).execute()
)
return UserIdentity(**response.data[0])

View File

@ -0,0 +1,21 @@
from uuid import UUID
from models.settings import common_dependencies
from models.user_identity import UserIdentity
from repository.user_identity.create_user_identity import create_user_identity
def get_user_identity(user_id: UUID) -> UserIdentity:
commons = common_dependencies()
response = (
commons["supabase"]
.from_("user_identity")
.select("*")
.filter("user_id", "eq", user_id)
.execute()
)
if len(response.data) == 0:
return create_user_identity(UserIdentity(user_id=user_id))
return UserIdentity(**response.data[0])

View File

@ -0,0 +1,36 @@
from typing import Optional
from uuid import UUID
from models.settings import common_dependencies
from models.user_identity import UserIdentity
from pydantic import BaseModel
from repository.user_identity.create_user_identity import (
create_user_identity,
)
class UserIdentityUpdatableProperties(BaseModel):
openai_api_key: Optional[str]
def update_user_identity(
user_id: UUID,
user_identity_updatable_properties: UserIdentityUpdatableProperties,
) -> UserIdentity:
commons = common_dependencies()
response = (
commons["supabase"]
.from_("user_identity")
.update(user_identity_updatable_properties.__dict__)
.filter("user_id", "eq", user_id)
.execute()
)
if len(response.data) == 0:
user_identity = UserIdentity(
user_id=user_id,
openai_api_key=user_identity_updatable_properties.openai_api_key,
)
return create_user_identity(user_identity)
return UserIdentity(**response.data[0])

View File

@ -7,6 +7,7 @@ from models.brains import Brain
from models.files import File from models.files import File
from models.settings import common_dependencies from models.settings import common_dependencies
from models.users import User from models.users import User
from repository.user_identity.get_user_identity import get_user_identity
from utils.file import convert_bytes, get_file_size from utils.file import convert_bytes, get_file_size
from utils.processors import filter_file from utils.processors import filter_file
@ -59,12 +60,19 @@ async def upload_file(
"type": "error", "type": "error",
} }
else: else:
openai_api_key = request.headers.get("Openai-Api-Key", None)
if openai_api_key is None:
openai_api_key = brain.get_brain_details()["openai_api_key"]
if openai_api_key is None:
openai_api_key = get_user_identity(current_user.id).openai_api_key
message = await filter_file( message = await filter_file(
commons, commons,
file, file,
enable_summarization, enable_summarization,
brain_id=brain_id, brain_id=brain_id,
openai_api_key=request.headers.get("Openai-Api-Key", None), openai_api_key=openai_api_key,
) )
return message return message

View File

@ -5,7 +5,13 @@ from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from models.brains import Brain, get_default_user_brain from models.brains import Brain, get_default_user_brain
from models.settings import BrainRateLimiting from models.settings import BrainRateLimiting
from models.user_identity import UserIdentity
from models.users import User from models.users import User
from repository.user_identity.get_user_identity import get_user_identity
from repository.user_identity.update_user_identity import (
UserIdentityUpdatableProperties,
update_user_identity,
)
user_router = APIRouter() user_router = APIRouter()
@ -56,3 +62,32 @@ async def get_user_endpoint(
"requests_stats": requests_stats, "requests_stats": requests_stats,
"date": date, "date": date,
} }
@user_router.put(
"/user/identity",
dependencies=[Depends(AuthBearer())],
tags=["User"],
)
def update_user_identity_route(
user_identity_updatable_properties: UserIdentityUpdatableProperties,
current_user: User = Depends(get_current_user),
) -> UserIdentity:
"""
Update user identity.
"""
return update_user_identity(current_user.id, user_identity_updatable_properties)
@user_router.get(
"/user/identity",
dependencies=[Depends(AuthBearer())],
tags=["User"],
)
def get_user_identity_route(
current_user: User = Depends(get_current_user),
) -> UserIdentity:
"""
Get user identity.
"""
return get_user_identity(current_user.id)

View File

@ -9,25 +9,30 @@ sidebar_position: 1
**Swagger**: https://api.quivr.app/docs **Swagger**: https://api.quivr.app/docs
## Overview ## Overview
This documentation outlines the key points and usage instructions for interacting with the API backend. Please follow the guidelines below to use the backend services effectively. This documentation outlines the key points and usage instructions for interacting with the API backend. Please follow the guidelines below to use the backend services effectively.
## Usage Instructions ## Usage Instructions
1. Standalone Backend 1. Standalone Backend
- The backend can now be used independently without the frontend application. - The backend can now be used independently without the frontend application.
- Users can interact with the API endpoints directly using API testing tools like Postman. - Users can interact with the API endpoints directly using API testing tools like Postman.
2. Generating API Key 2. Generating API Key
- To access the backend services, you need to sign in to the frontend application. - To access the backend services, you need to sign in to the frontend application.
- Once signed in, navigate to the `/config` page to generate a new API key. - Once signed in, navigate to the `/user` page to generate a new API key.
- The API key will be required to authenticate your requests to the backend. - The API key will be required to authenticate your requests to the backend.
3. Authenticating Requests 3. Authenticating Requests
- When making requests to the backend API, include the following header: - When making requests to the backend API, include the following header:
- `Authorization: Bearer {api_key}` - `Authorization: Bearer {api_key}`
- Replace `{api_key}` with the generated API key obtained from the frontend. - Replace `{api_key}` with the generated API key obtained from the frontend.
4. Future Plans 4. Future Plans
- The development team has plans to introduce additional features and improvements. - The development team has plans to introduce additional features and improvements.
- These include the ability to delete API keys and view the list of active keys. - These include the ability to delete API keys and view the list of active keys.
- The GitHub roadmap will provide more details on upcoming features, including addressing active issues. - The GitHub roadmap will provide more details on upcoming features, including addressing active issues.

View File

@ -5,7 +5,12 @@ import Button from "@/lib/components/ui/Button";
import { useApiKeyConfig } from "./hooks/useApiKeyConfig"; import { useApiKeyConfig } from "./hooks/useApiKeyConfig";
export const ApiKeyConfig = (): JSX.Element => { export const ApiKeyConfig = (): JSX.Element => {
const { apiKey, handleCopyClick, handleCreateClick } = useApiKeyConfig(); const {
apiKey,
handleCopyClick,
handleCreateClick,
} = useApiKeyConfig();
return ( return (
<> <>

View File

@ -6,6 +6,7 @@ import { useEventTracking } from "@/services/analytics/useEventTracking";
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useApiKeyConfig = () => { export const useApiKeyConfig = () => {
const [apiKey, setApiKey] = useState(""); const [apiKey, setApiKey] = useState("");
const [openAiApiKey, setOpenAiApiKey] = useState("");
const { track } = useEventTracking(); const { track } = useEventTracking();
const { createApiKey } = useAuthApi(); const { createApiKey } = useAuthApi();
@ -38,5 +39,7 @@ export const useApiKeyConfig = () => {
handleCreateClick, handleCreateClick,
apiKey, apiKey,
handleCopyClick, handleCopyClick,
openAiApiKey,
setOpenAiApiKey,
}; };
}; };

View File

@ -1,12 +1,25 @@
/* eslint-disable max-lines */
"use client"; "use client";
import Button from "@/lib/components/ui/Button"; import Button from "@/lib/components/ui/Button";
import { Divider } from "@/lib/components/ui/Divider"; import { Divider } from "@/lib/components/ui/Divider";
import Field from "@/lib/components/ui/Field";
import { useApiKeyConfig } from "./hooks/useApiKeyConfig"; import { useApiKeyConfig } from "./hooks/useApiKeyConfig";
export const ApiKeyConfig = (): JSX.Element => { export const ApiKeyConfig = (): JSX.Element => {
const { apiKey, handleCopyClick, handleCreateClick } = useApiKeyConfig(); const {
apiKey,
handleCopyClick,
handleCreateClick,
openAiApiKey,
setOpenAiApiKey,
changeOpenAiApiKey,
changeOpenAiApiKeyRequestPending,
userIdentity,
removeOpenAiApiKey,
hasOpenAiApiKey,
} = useApiKeyConfig();
return ( return (
<> <>
@ -36,6 +49,53 @@ export const ApiKeyConfig = (): JSX.Element => {
</div> </div>
)} )}
</div> </div>
<Divider text="OpenAI Key" className="mt-4 mb-4" />
<div className="flex mb-4 justify-center items-center mt-5">
<div className="bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded relative max-w-md">
<span className="block sm:inline">
Your api key will be saved in our data. We will not use it for any
other purpose. However,{" "}
<strong>
we {"don't"} have not implemented any encryption logic yet
</strong>
</span>
</div>
</div>
<form
onSubmit={(event) => {
event.preventDefault();
void changeOpenAiApiKey();
}}
>
<Field
name="openAiApiKey"
placeholder="Open AI Key"
className="w-full"
value={openAiApiKey ?? ""}
data-testid="open-ai-api-key"
onChange={(e) => setOpenAiApiKey(e.target.value)}
/>
<div className="mt-4 flex flex-row justify-between">
{hasOpenAiApiKey && (
<Button
isLoading={changeOpenAiApiKeyRequestPending}
variant="secondary"
onClick={() => void removeOpenAiApiKey()}
>
Remove Key
</Button>
)}
<Button
data-testid="save-open-ai-api-key"
isLoading={changeOpenAiApiKeyRequestPending}
disabled={openAiApiKey === userIdentity?.openai_api_key}
>
Save Key
</Button>
</div>
</form>
</> </>
); );
}; };

View File

@ -22,8 +22,11 @@ describe("ApiKeyConfig", () => {
}); });
it("should render ApiConfig Component", () => { it("should render ApiConfig Component", () => {
const { getByText } = render(<ApiKeyConfig />); const { getByText, getByTestId } = render(<ApiKeyConfig />);
expect(getByText("API Key Config")).toBeDefined(); expect(getByText("API Key Config")).toBeDefined();
expect(getByText("OpenAI Key")).toBeDefined();
expect(getByTestId("open-ai-api-key")).toBeDefined();
expect(getByTestId("save-open-ai-api-key")).toBeDefined();
}); });
it("renders 'Create New Key' button when apiKey is empty", () => { it("renders 'Create New Key' button when apiKey is empty", () => {

View File

@ -6,6 +6,12 @@ import { useApiKeyConfig } from "../useApiKeyConfig";
const createApiKeyMock = vi.fn(() => "dummyApiKey"); const createApiKeyMock = vi.fn(() => "dummyApiKey");
const trackMock = vi.fn((props: unknown) => ({ props })); const trackMock = vi.fn((props: unknown) => ({ props }));
const mockUseSupabase = vi.fn(() => ({
session: {
user: {},
},
}));
const useAuthApiMock = vi.fn(() => ({ const useAuthApiMock = vi.fn(() => ({
createApiKey: () => createApiKeyMock(), createApiKey: () => createApiKeyMock(),
})); }));
@ -20,6 +26,31 @@ vi.mock("@/lib/api/auth/useAuthApi", () => ({
vi.mock("@/services/analytics/useEventTracking", () => ({ vi.mock("@/services/analytics/useEventTracking", () => ({
useEventTracking: () => useEventTrackingMock(), useEventTracking: () => useEventTrackingMock(),
})); }));
vi.mock("@/lib/context/SupabaseProvider", () => ({
useSupabase: () => mockUseSupabase(),
}));
vi.mock("@/lib/hooks", async () => {
const actual = await vi.importActual<typeof import("@/lib/hooks")>(
"@/lib/hooks"
);
return {
...actual,
useAxios: () => ({
axiosInstance: {
put: vi.fn(() => ({})),
get: vi.fn(() => ({})),
},
}),
};
});
vi.mock("@/lib/context/BrainConfigProvider", () => ({
useBrainConfig: () => ({
config: {},
}),
}));
describe("useApiKeyConfig", () => { describe("useApiKeyConfig", () => {
afterEach(() => { afterEach(() => {

View File

@ -1,13 +1,32 @@
import { useState } from "react"; /* eslint-disable max-lines */
import { useEffect, useState } from "react";
import { useAuthApi } from "@/lib/api/auth/useAuthApi"; import { useAuthApi } from "@/lib/api/auth/useAuthApi";
import { useUserApi } from "@/lib/api/user/useUserApi";
import { UserIdentity } from "@/lib/api/user/user";
import { useToast } from "@/lib/hooks";
import { useEventTracking } from "@/services/analytics/useEventTracking"; import { useEventTracking } from "@/services/analytics/useEventTracking";
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useApiKeyConfig = () => { export const useApiKeyConfig = () => {
const [apiKey, setApiKey] = useState(""); const [apiKey, setApiKey] = useState("");
const [openAiApiKey, setOpenAiApiKey] = useState<string | null>();
const [
changeOpenAiApiKeyRequestPending,
setChangeOpenAiApiKeyRequestPending,
] = useState(false);
const { updateUserIdentity, getUserIdentity } = useUserApi();
const { track } = useEventTracking(); const { track } = useEventTracking();
const { createApiKey } = useAuthApi(); const { createApiKey } = useAuthApi();
const { publish } = useToast();
const [userIdentity, setUserIdentity] = useState<UserIdentity>();
const fetchUserIdentity = async () => {
setUserIdentity(await getUserIdentity());
};
useEffect(() => {
void fetchUserIdentity();
}, []);
const handleCreateClick = async () => { const handleCreateClick = async () => {
try { try {
@ -34,9 +53,65 @@ export const useApiKeyConfig = () => {
} }
}; };
const changeOpenAiApiKey = async () => {
try {
setChangeOpenAiApiKeyRequestPending(true);
await updateUserIdentity({
openai_api_key: openAiApiKey,
});
void fetchUserIdentity();
publish({
variant: "success",
text: "OpenAI API Key updated",
});
} catch (error) {
console.error(error);
} finally {
setChangeOpenAiApiKeyRequestPending(false);
}
};
const removeOpenAiApiKey = async () => {
try {
setChangeOpenAiApiKeyRequestPending(true);
await updateUserIdentity({
openai_api_key: null,
});
publish({
variant: "success",
text: "OpenAI API Key removed",
});
void fetchUserIdentity();
} catch (error) {
console.error(error);
} finally {
setChangeOpenAiApiKeyRequestPending(false);
}
};
useEffect(() => {
if (userIdentity?.openai_api_key !== undefined) {
setOpenAiApiKey(userIdentity.openai_api_key);
}
}, [userIdentity]);
const hasOpenAiApiKey =
userIdentity?.openai_api_key !== null &&
userIdentity?.openai_api_key !== undefined &&
userIdentity.openai_api_key !== "";
return { return {
handleCreateClick, handleCreateClick,
apiKey, apiKey,
handleCopyClick, handleCopyClick,
openAiApiKey,
setOpenAiApiKey,
changeOpenAiApiKey,
changeOpenAiApiKeyRequestPending,
userIdentity,
removeOpenAiApiKey,
hasOpenAiApiKey,
}; };
}; };

View File

@ -0,0 +1,48 @@
import { renderHook } from "@testing-library/react";
import { describe, expect, it, vi } from "vitest";
import { useUserApi } from "../useUserApi";
import { UserIdentityUpdatableProperties } from "../user";
const axiosPutMock = vi.fn(() => ({}));
const axiosGetMock = vi.fn(() => ({}));
vi.mock("@/lib/hooks", () => ({
useAxios: () => ({
axiosInstance: {
put: axiosPutMock,
get: axiosGetMock,
},
}),
}));
describe("useUserApi", () => {
it("should call updateUserIdentity with the correct parameters", async () => {
const {
result: {
current: { updateUserIdentity },
},
} = renderHook(() => useUserApi());
const userUpdatableProperties: UserIdentityUpdatableProperties = {
openai_api_key: "sk-xxx",
};
await updateUserIdentity(userUpdatableProperties);
expect(axiosPutMock).toHaveBeenCalledTimes(1);
expect(axiosPutMock).toHaveBeenCalledWith(
`/user/identity`,
userUpdatableProperties
);
});
it("should call getUserIdentity with the correct parameters", async () => {
const {
result: {
current: { getUserIdentity },
},
} = renderHook(() => useUserApi());
await getUserIdentity();
expect(axiosGetMock).toHaveBeenCalledTimes(1);
expect(axiosGetMock).toHaveBeenCalledWith(`/user/identity`);
});
});

View File

@ -0,0 +1,19 @@
import { useAxios } from "@/lib/hooks";
import {
getUserIdentity,
updateUserIdentity,
UserIdentityUpdatableProperties,
} from "./user";
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useUserApi = () => {
const { axiosInstance } = useAxios();
return {
updateUserIdentity: async (
userIdentityUpdatableProperties: UserIdentityUpdatableProperties
) => updateUserIdentity(userIdentityUpdatableProperties, axiosInstance),
getUserIdentity: async () => getUserIdentity(axiosInstance),
};
};

View File

@ -0,0 +1,25 @@
import { AxiosInstance } from "axios";
import { UUID } from "crypto";
export type UserIdentityUpdatableProperties = {
openai_api_key?: string | null;
};
export type UserIdentity = {
openai_api_key?: string | null;
user_id: UUID;
};
export const updateUserIdentity = async (
userUpdatableProperties: UserIdentityUpdatableProperties,
axiosInstance: AxiosInstance
): Promise<UserIdentity> =>
axiosInstance.put(`/user/identity`, userUpdatableProperties);
export const getUserIdentity = async (
axiosInstance: AxiosInstance
): Promise<UserIdentity> => {
const { data } = await axiosInstance.get<UserIdentity>(`/user/identity`);
return data;
};

View File

@ -0,0 +1,16 @@
BEGIN;
-- Create user_identity table if it doesn't exist
CREATE TABLE IF NOT EXISTS user_identity (
user_id UUID PRIMARY KEY,
openai_api_key VARCHAR(255)
);
-- Insert migration record if it doesn't exist
INSERT INTO migrations (name)
SELECT '20230731172400_add_user_identity_table'
WHERE NOT EXISTS (
SELECT 1 FROM migrations WHERE name = '20230731172400_add_user_identity_table'
);
COMMIT;

View File

@ -167,6 +167,12 @@ CREATE TABLE IF NOT EXISTS brain_subscription_invitations (
FOREIGN KEY (brain_id) REFERENCES brains (brain_id) FOREIGN KEY (brain_id) REFERENCES brains (brain_id)
); );
--- Create user_identity table
CREATE TABLE IF NOT EXISTS user_identity (
user_id UUID PRIMARY KEY,
openai_api_key VARCHAR(255)
);
CREATE OR REPLACE FUNCTION public.get_user_email_by_user_id(user_id uuid) CREATE OR REPLACE FUNCTION public.get_user_email_by_user_id(user_id uuid)
RETURNS TABLE (email text) RETURNS TABLE (email text)
SECURITY definer SECURITY definer
@ -194,7 +200,7 @@ CREATE TABLE IF NOT EXISTS migrations (
); );
INSERT INTO migrations (name) INSERT INTO migrations (name)
SELECT '202307241530031_add_fields_to_brain' SELECT '20230731172400_add_user_identity_table'
WHERE NOT EXISTS ( WHERE NOT EXISTS (
SELECT 1 FROM migrations WHERE name = '202307241530031_add_fields_to_brain' SELECT 1 FROM migrations WHERE name = '20230731172400_add_user_identity_table'
); );