feat: add custom prompt fields on brain setting pages (#837)

* feat(sdk): add prompt apis to sdk

* feat: implement prompt creation-n

* feat: add brain custom prompt fields

* fix: change tables creation order
This commit is contained in:
Mamadou DICKO 2023-08-03 15:41:24 +02:00 committed by GitHub
parent a87cd113f7
commit 99a3fa9b29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 377 additions and 75 deletions

View File

@ -6,7 +6,7 @@ import { FaSpinner } from "react-icons/fa";
import Button from "@/lib/components/ui/Button";
import { Divider } from "@/lib/components/ui/Divider";
import Field from "@/lib/components/ui/Field";
import { TextArea } from "@/lib/components/ui/TextField";
import { TextArea } from "@/lib/components/ui/TextArea";
import { models, paidModels } from "@/lib/context/BrainConfigProvider/types";
import { defineMaxTokens } from "@/lib/helpers/defineMexTokens";
@ -29,6 +29,8 @@ export const SettingsTab = ({ brainId }: SettingsTabProps): JSX.Element => {
isUpdating,
isDefaultBrain,
formRef,
promptId,
removeBrainPrompt,
} = useSettingsTab({ brainId });
return (
@ -126,6 +128,26 @@ export const SettingsTab = ({ brainId }: SettingsTabProps): JSX.Element => {
{...register("maxTokens")}
/>
</fieldset>
<Divider text="prompt" />
<Field
label="Prompt title"
placeholder="My awesome prompt name"
autoComplete="off"
className="flex-1"
{...register("prompt.title")}
/>
<TextArea
label="Prompt content"
placeholder="As an AI, your..."
autoComplete="off"
className="flex-1"
{...register("prompt.content")}
/>
{promptId !== "" && (
<Button disabled={isUpdating} onClick={() => void removeBrainPrompt()}>
Remove prompt
</Button>
)}
<div className="flex flex-row justify-end flex-1 w-full mt-8">
{isUpdating && <FaSpinner className="animate-spin" />}
{isUpdating && (

View File

@ -6,6 +6,7 @@ import { useEffect, useRef, useState } from "react";
import { useForm } from "react-hook-form";
import { useBrainApi } from "@/lib/api/brain/useBrainApi";
import { usePromptApi } from "@/lib/api/prompt/usePromptApi";
import { useBrainConfig } from "@/lib/context/BrainConfigProvider";
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
import { Brain } from "@/lib/context/BrainProvider/types";
@ -26,67 +27,106 @@ export const useSettingsTab = ({ brainId }: UseSettingsTabProps) => {
const { config } = useBrainConfig();
const { fetchAllBrains, fetchDefaultBrain, defaultBrainId } =
useBrainContext();
const { getPrompt, updatePrompt, createPrompt } = usePromptApi();
const defaultValues = {
...config,
name: "",
description: "",
setDefault: false,
prompt_id: "",
prompt: {
title: "",
content: "",
},
};
const {
register,
getValues,
watch,
setValue,
reset,
formState: { dirtyFields },
} = useForm({
defaultValues,
});
useEffect(() => {
const fetchBrain = async () => {
const brain = await getBrain(brainId);
if (brain === undefined) {
return;
}
for (const key in brain) {
const brainKey = key as keyof Brain;
if (!(key in brain)) {
return;
}
if (brainKey === "max_tokens" && brain["max_tokens"] !== undefined) {
setValue("maxTokens", brain["max_tokens"]);
continue;
}
if (
brainKey === "openai_api_key" &&
brain["openai_api_key"] !== undefined
) {
setValue("openAiKey", brain["openai_api_key"]);
continue;
}
// @ts-expect-error bad type inference from typescript
// eslint-disable-next-line
setValue(key, brain[key]);
}
};
void fetchBrain();
}, []);
const isDefaultBrain = defaultBrainId === brainId;
const promptId = watch("prompt_id");
const openAiKey = watch("openAiKey");
const model = watch("model");
const temperature = watch("temperature");
const maxTokens = watch("maxTokens");
const fetchBrain = async () => {
const brain = await getBrain(brainId);
if (brain === undefined) {
return;
}
for (const key in brain) {
const brainKey = key as keyof Brain;
if (!(key in brain)) {
return;
}
if (brainKey === "max_tokens" && brain["max_tokens"] !== undefined) {
setValue("maxTokens", brain["max_tokens"]);
continue;
}
if (
brainKey === "openai_api_key" &&
brain["openai_api_key"] !== undefined
) {
setValue("openAiKey", brain["openai_api_key"]);
continue;
}
// @ts-expect-error bad type inference from typescript
// eslint-disable-next-line
if (Boolean(brain[key])) setValue(key, brain[key]);
}
};
useEffect(() => {
void fetchBrain();
}, []);
useEffect(() => {
setValue("maxTokens", Math.min(maxTokens, defineMaxTokens(model)));
}, [maxTokens, model, setValue]);
useEffect(() => {
const handleKeyPress = (event: KeyboardEvent) => {
if (event.key === "Enter") {
event.preventDefault();
void handleSubmit();
}
};
formRef.current?.addEventListener("keydown", handleKeyPress);
return () => {
formRef.current?.removeEventListener("keydown", handleKeyPress);
};
}, [formRef.current]);
const fetchPrompt = async () => {
if (promptId === "") {
return;
}
const prompt = await getPrompt(promptId);
if (prompt === undefined) {
return;
}
setValue("prompt", prompt);
};
useEffect(() => {
void fetchPrompt();
}, [promptId]);
const setAsDefaultBrainHandler = async () => {
try {
setIsSettingHasDefault(true);
@ -117,6 +157,43 @@ export const useSettingsTab = ({ brainId }: UseSettingsTabProps) => {
}
};
const removeBrainPrompt = async () => {
try {
setIsUpdating(true);
await updateBrain(brainId, {
prompt_id: null,
});
setValue("prompt", {
title: "",
content: "",
});
reset();
void fetchBrain();
publish({
variant: "success",
text: "Prompt removed successfully",
});
} catch (err) {
publish({
variant: "danger",
text: "Error while removing prompt",
});
} finally {
setIsUpdating(false);
}
};
const promptHandler = async () => {
const { prompt } = getValues();
if (dirtyFields["prompt"]) {
await updatePrompt(promptId, {
title: prompt.title,
content: prompt.content,
});
}
};
const handleSubmit = async () => {
const hasChanges = Object.keys(dirtyFields).length > 0;
@ -140,14 +217,55 @@ export const useSettingsTab = ({ brainId }: UseSettingsTabProps) => {
const {
maxTokens: max_tokens,
openAiKey: openai_api_key,
prompt,
...otherConfigs
} = getValues();
await updateBrain(brainId, {
...otherConfigs,
max_tokens,
openai_api_key,
});
if (
dirtyFields["prompt"] &&
(prompt.content === "" || prompt.title === "")
) {
publish({
variant: "warning",
text: "Prompt title and content are required",
});
return;
}
if (dirtyFields["prompt"]) {
if (promptId === "") {
otherConfigs["prompt_id"] = (
await createPrompt({
title: prompt.title,
content: prompt.content,
})
).id;
await updateBrain(brainId, {
...otherConfigs,
max_tokens,
openai_api_key,
});
void fetchBrain();
} else {
await Promise.all([
updateBrain(brainId, {
...otherConfigs,
max_tokens,
openai_api_key,
}),
promptHandler(),
]);
}
return;
} else {
await updateBrain(brainId, {
...otherConfigs,
max_tokens,
openai_api_key,
});
}
publish({
variant: "success",
@ -176,22 +294,6 @@ export const useSettingsTab = ({ brainId }: UseSettingsTabProps) => {
setIsUpdating(false);
}
};
const isDefaultBrain = defaultBrainId === brainId;
useEffect(() => {
const handleKeyPress = (event: KeyboardEvent) => {
if (event.key === "Enter") {
event.preventDefault();
void handleSubmit();
}
};
formRef.current?.addEventListener("keydown", handleKeyPress);
return () => {
formRef.current?.removeEventListener("keydown", handleKeyPress);
};
}, [formRef.current]);
return {
handleSubmit,
@ -205,5 +307,7 @@ export const useSettingsTab = ({ brainId }: UseSettingsTabProps) => {
isSettingAsDefault,
isDefaultBrain,
formRef,
promptId,
removeBrainPrompt,
};
};

View File

@ -12,6 +12,7 @@ export type CreateBrainInput = {
temperature?: number;
max_tokens?: number;
openai_api_key?: string;
prompt_id?: string | null;
};
export type UpdateBrainInput = Partial<CreateBrainInput>;

View File

@ -0,0 +1,70 @@
import { renderHook } from "@testing-library/react";
import { describe, expect, it, vi } from "vitest";
import { CreatePromptProps, PromptUpdatableProperties } from "../prompt";
import { usePromptApi } from "../usePromptApi";
const axiosPostMock = vi.fn(() => ({}));
const axiosGetMock = vi.fn(() => ({}));
const axiosPutMock = vi.fn(() => ({}));
vi.mock("@/lib/hooks", () => ({
useAxios: () => ({
axiosInstance: {
post: axiosPostMock,
get: axiosGetMock,
put: axiosPutMock,
},
}),
}));
describe("usePromptApi", () => {
it("should call createPrompt with the correct parameters", async () => {
const prompt: CreatePromptProps = {
title: "Test Prompt",
content: "Test Content",
};
axiosPostMock.mockReturnValue({ data: {} });
const {
result: {
current: { createPrompt },
},
} = renderHook(() => usePromptApi());
await createPrompt(prompt);
expect(axiosPostMock).toHaveBeenCalledTimes(1);
expect(axiosPostMock).toHaveBeenCalledWith("/prompts", prompt);
});
it("should call getPrompt with the correct parameters", async () => {
const promptId = "test-prompt-id";
axiosGetMock.mockReturnValue({ data: {} });
const {
result: {
current: { getPrompt },
},
} = renderHook(() => usePromptApi());
await getPrompt(promptId);
expect(axiosGetMock).toHaveBeenCalledTimes(1);
expect(axiosGetMock).toHaveBeenCalledWith(`/prompts/${promptId}`);
});
it("should call updatePrompt with the correct parameters", async () => {
const promptId = "test-prompt-id";
const prompt: PromptUpdatableProperties = {
title: "Test Prompt",
content: "Test Content",
};
axiosPutMock.mockReturnValue({ data: {} });
const {
result: {
current: { updatePrompt },
},
} = renderHook(() => usePromptApi());
await updatePrompt(promptId, prompt);
expect(axiosPutMock).toHaveBeenCalledTimes(1);
expect(axiosPutMock).toHaveBeenCalledWith(`/prompts/${promptId}`, prompt);
});
});

View File

@ -0,0 +1,34 @@
import { AxiosInstance } from "axios";
import { Prompt } from "@/lib/types/Prompt";
export type CreatePromptProps = {
title: string;
content: string;
};
export const createPrompt = async (
prompt: CreatePromptProps,
axiosInstance: AxiosInstance
): Promise<Prompt> => {
return (await axiosInstance.post<Prompt>("/prompts", prompt)).data;
};
export const getPrompt = async (
promptId: string,
axiosInstance: AxiosInstance
): Promise<Prompt | undefined> => {
return (await axiosInstance.get<Prompt>(`/prompts/${promptId}`)).data;
};
export type PromptUpdatableProperties = {
title: string;
content: string;
};
export const updatePrompt = async (
promptId: string,
prompt: PromptUpdatableProperties,
axiosInstance: AxiosInstance
): Promise<Prompt> => {
return (await axiosInstance.put<Prompt>(`/prompts/${promptId}`, prompt)).data;
};

View File

@ -0,0 +1,22 @@
import { useAxios } from "@/lib/hooks";
import {
createPrompt,
CreatePromptProps,
getPrompt,
PromptUpdatableProperties,
updatePrompt,
} from "./prompt";
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const usePromptApi = () => {
const { axiosInstance } = useAxios();
return {
createPrompt: async (prompt: CreatePromptProps) =>
createPrompt(prompt, axiosInstance),
getPrompt: async (promptId: string) => getPrompt(promptId, axiosInstance),
updatePrompt: async (promptId: string, prompt: PromptUpdatableProperties) =>
updatePrompt(promptId, prompt, axiosInstance),
};
};

View File

@ -8,7 +8,8 @@ import { models, paidModels } from "@/lib/context/BrainConfigProvider/types";
import { defineMaxTokens } from "@/lib/helpers/defineMexTokens";
import { useAddBrainModal } from "./hooks/useAddBrainModal";
import { TextArea } from "../ui/TextField";
import { Divider } from "../ui/Divider";
import { TextArea } from "../ui/TextArea";
export const AddBrainModal = (): JSX.Element => {
const {
@ -38,7 +39,10 @@ export const AddBrainModal = (): JSX.Element => {
CloseTrigger={<div />}
>
<form
onSubmit={(e) => void handleSubmit(e)}
onSubmit={(e) => {
e.preventDefault();
void handleSubmit();
}}
className="my-10 flex flex-col items-center gap-2"
>
<Field
@ -111,6 +115,21 @@ export const AddBrainModal = (): JSX.Element => {
{...register("maxTokens")}
/>
</fieldset>
<Divider text="Custom prompt" />
<Field
label="Prompt title"
placeholder="My awesome prompt name"
autoComplete="off"
className="flex-1"
{...register("prompt.title")}
/>
<TextArea
label="Prompt content"
placeholder="As an AI, your..."
autoComplete="off"
className="flex-1"
{...register("prompt.content")}
/>
<div className="flex flex-row justify-start w-full mt-4">
<label className="flex items-center">
<span className="mr-2 text-gray-700">Set as default brain</span>

View File

@ -1,9 +1,10 @@
/* eslint-disable max-lines */
import axios from "axios";
import { FormEvent, useEffect, useState } from "react";
import { useEffect, useState } from "react";
import { useForm } from "react-hook-form";
import { useBrainApi } from "@/lib/api/brain/useBrainApi";
import { usePromptApi } from "@/lib/api/prompt/usePromptApi";
import { useBrainConfig } from "@/lib/context/BrainConfigProvider";
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
import { defineMaxTokens } from "@/lib/helpers/defineMexTokens";
@ -15,6 +16,7 @@ export const useAddBrainModal = () => {
const { publish } = useToast();
const { createBrain, setActiveBrain } = useBrainContext();
const { setAsDefaultBrain } = useBrainApi();
const { createPrompt } = usePromptApi();
const [isShareModalOpen, setIsShareModalOpen] = useState(false);
const { config } = useBrainConfig();
const defaultValues = {
@ -22,6 +24,10 @@ export const useAddBrainModal = () => {
name: "",
description: "",
setDefault: false,
prompt: {
title: "",
content: "",
},
};
const { register, getValues, reset, watch, setValue } = useForm({
@ -37,8 +43,17 @@ export const useAddBrainModal = () => {
setValue("maxTokens", Math.min(maxTokens, defineMaxTokens(model)));
}, [maxTokens, model, setValue]);
const handleSubmit = async (e: FormEvent) => {
e.preventDefault();
const getCreatingBrainPromptId = async (): Promise<string | undefined> => {
const { prompt } = getValues();
if (prompt.title.trim() !== "" && prompt.content.trim() !== "") {
return (await createPrompt(prompt)).id;
}
return undefined;
};
const handleSubmit = async () => {
const { name, description, setDefault } = getValues();
if (name.trim() === "" || isPending) {
@ -47,6 +62,9 @@ export const useAddBrainModal = () => {
try {
setIsPending(true);
const prompt_id = await getCreatingBrainPromptId();
const createdBrainId = await createBrain({
name,
description,
@ -54,6 +72,7 @@ export const useAddBrainModal = () => {
model,
openai_api_key: openAiKey,
temperature,
prompt_id,
});
if (createdBrainId === undefined) {
@ -92,12 +111,13 @@ export const useAddBrainModal = () => {
).data.detail
)}`,
});
} else {
publish({
variant: "danger",
text: `${JSON.stringify(err)}`,
});
return;
}
publish({
variant: "danger",
text: `${JSON.stringify(err)}`,
});
} finally {
setIsPending(false);
}

View File

@ -24,6 +24,7 @@ const defaultBrainConfig: BrainConfig = {
openAiKey: undefined,
supabaseKey: undefined,
supabaseUrl: undefined,
prompt_id: undefined,
};
export const BrainConfigProvider = ({

View File

@ -1,3 +1,5 @@
import { UUID } from "crypto";
export type BrainConfig = {
model: Model;
temperature: number;
@ -8,6 +10,7 @@ export type BrainConfig = {
anthropicKey?: string;
supabaseUrl?: string;
supabaseKey?: string;
prompt_id?: UUID;
};
type OptionalConfig = { [K in keyof BrainConfig]?: BrainConfig[K] | undefined };

View File

@ -16,6 +16,7 @@ export type Brain = {
temperature?: number;
openai_api_key?: string;
description?: string;
prompt_id?: string | null;
};
export type MinimalBrainForUser = {

View File

@ -0,0 +1,5 @@
export type Prompt = {
id: string;
title: string;
content: string;
};

View File

@ -126,6 +126,15 @@ CREATE TABLE IF NOT EXISTS api_keys(
is_active BOOLEAN DEFAULT true
);
--- Create prompts table
CREATE TABLE IF NOT EXISTS prompts (
id UUID DEFAULT uuid_generate_v4() PRIMARY KEY,
title VARCHAR(255),
content TEXT,
status VARCHAR(255) DEFAULT 'private'
);
--- Create brains table
CREATE TABLE IF NOT EXISTS brains (
brain_id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
name TEXT NOT NULL,
@ -176,15 +185,6 @@ CREATE TABLE IF NOT EXISTS user_identity (
);
--- Create prompts table
CREATE TABLE IF NOT EXISTS prompts (
id UUID DEFAULT uuid_generate_v4() PRIMARY KEY,
title VARCHAR(255),
content TEXT,
status VARCHAR(255) DEFAULT 'private'
);
CREATE OR REPLACE FUNCTION public.get_user_email_by_user_id(user_id uuid)
RETURNS TABLE (email text)
SECURITY definer