diff --git a/app/src/lib/ai/anthropicClient.ts b/app/src/lib/ai/anthropicClient.ts
index 07a6d5cb0..f68532ef9 100644
--- a/app/src/lib/ai/anthropicClient.ts
+++ b/app/src/lib/ai/anthropicClient.ts
@@ -1,6 +1,6 @@
 import { SHORT_DEFAULT_COMMIT_TEMPLATE, SHORT_DEFAULT_BRANCH_TEMPLATE } from '$lib/ai/prompts';
 import { fetch, Body } from '@tauri-apps/api/http';
-import type { AIClient, AnthropicModelName, PromptMessage } from '$lib/ai/types';
+import type { AIClient, AnthropicModelName, Prompt } from '$lib/ai/types';
 
 type AnthropicAPIResponse = { content: { text: string }[] };
 
@@ -13,7 +13,7 @@ export class AnthropicAIClient implements AIClient {
 		private modelName: AnthropicModelName
 	) {}
 
-	async evaluate(prompt: PromptMessage[]) {
+	async evaluate(prompt: Prompt) {
 		const body = Body.json({
 			messages: prompt,
 			max_tokens: 1024,
diff --git a/app/src/lib/ai/butlerClient.ts b/app/src/lib/ai/butlerClient.ts
index a228e83c8..56454ef24 100644
--- a/app/src/lib/ai/butlerClient.ts
+++ b/app/src/lib/ai/butlerClient.ts
@@ -1,5 +1,5 @@
 import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
-import type { AIClient, ModelKind, PromptMessage } from '$lib/ai/types';
+import type { AIClient, ModelKind, Prompt } from '$lib/ai/types';
 import type { HttpClient } from '$lib/backend/httpClient';
 
 export class ButlerAIClient implements AIClient {
@@ -12,7 +12,7 @@ export class ButlerAIClient implements AIClient {
 		private modelKind: ModelKind
 	) {}
 
-	async evaluate(prompt: PromptMessage[]) {
+	async evaluate(prompt: Prompt) {
 		const response = await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', {
 			body: {
 				messages: prompt,
diff --git a/app/src/lib/ai/ollamaClient.ts b/app/src/lib/ai/ollamaClient.ts
index 7c0369cc8..019b1d0fe 100644
--- a/app/src/lib/ai/ollamaClient.ts
+++ b/app/src/lib/ai/ollamaClient.ts
@@ -1,5 +1,5 @@
 import { LONG_DEFAULT_BRANCH_TEMPLATE, LONG_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
-import { MessageRole, type PromptMessage, type AIClient } from '$lib/ai/types';
+import { MessageRole, type PromptMessage, type AIClient, type Prompt } from '$lib/ai/types';
 import { isNonEmptyObject } from '$lib/utils/typeguards';
 import { fetch, Body, Response } from '@tauri-apps/api/http';
 
@@ -22,7 +22,7 @@ interface OllamaRequestOptions {
 interface OllamaChatRequest {
 	model: string;
-	messages: PromptMessage[];
+	messages: Prompt;
 	stream: boolean;
 	format?: 'json';
 	options?: OllamaRequestOptions;
 }
@@ -81,7 +81,7 @@ export class OllamaClient implements AIClient {
 		private modelName: string
 	) {}
 
-	async evaluate(prompt: PromptMessage[]) {
+	async evaluate(prompt: Prompt) {
 		const messages = this.formatPrompt(prompt);
 		const response = await this.chat(messages);
 		const rawResponse = JSON.parse(response.message.content);
@@ -96,7 +96,7 @@ export class OllamaClient implements AIClient {
 	 * Appends a system message which instructs the model to respond using a particular JSON schema
 	 * Modifies the prompt's Assistant messages to make use of the correct schema
 	 */
-	private formatPrompt(prompt: PromptMessage[]) {
+	private formatPrompt(prompt: Prompt) {
 		const withFormattedResponses = prompt.map((promptMessage) => {
 			if (promptMessage.role == MessageRole.Assistant) {
 				return {
@@ -146,7 +146,7 @@ ${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}`
 	 * @returns A Promise that resolves to an LLMResponse object representing the response from the LLM model.
 	 */
 	private async chat(
-		messages: PromptMessage[],
+		messages: Prompt,
 		options?: OllamaRequestOptions
 	): Promise<OllamaChatResponse> {
 		const result = await this.fetchChat({
diff --git a/app/src/lib/ai/openAIClient.ts b/app/src/lib/ai/openAIClient.ts
index 62ded36b8..09ea39d88 100644
--- a/app/src/lib/ai/openAIClient.ts
+++ b/app/src/lib/ai/openAIClient.ts
@@ -1,5 +1,5 @@
 import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
-import type { OpenAIModelName, PromptMessage, AIClient } from '$lib/ai/types';
+import type { OpenAIModelName, Prompt, AIClient } from '$lib/ai/types';
 import type OpenAI from 'openai';
 
 export class OpenAIClient implements AIClient {
@@ -11,7 +11,7 @@ export class OpenAIClient implements AIClient {
 		private openAI: OpenAI
 	) {}
 
-	async evaluate(prompt: PromptMessage[]) {
+	async evaluate(prompt: Prompt) {
 		const response = await this.openAI.chat.completions.create({
 			messages: prompt,
 			model: this.modelName,
diff --git a/app/src/lib/ai/promptService.ts b/app/src/lib/ai/promptService.ts
new file mode 100644
index 000000000..912b99de7
--- /dev/null
+++ b/app/src/lib/ai/promptService.ts
@@ -0,0 +1,96 @@
+import {
+	LONG_DEFAULT_BRANCH_TEMPLATE,
+	SHORT_DEFAULT_BRANCH_TEMPLATE,
+	LONG_DEFAULT_COMMIT_TEMPLATE,
+	SHORT_DEFAULT_COMMIT_TEMPLATE
+} from '$lib/ai/prompts';
+import { persisted, type Persisted } from '$lib/persisted/persisted';
+import { get } from 'svelte/store';
+import type { Prompt, Prompts, UserPrompt } from '$lib/ai/types';
+
+enum PromptPersistedKey {
+	Branch = 'aiBranchPrompts',
+	Commit = 'aiCommitPrompts'
+}
+
+export class PromptService {
+	get branchPrompts(): Prompts {
+		return {
+			defaultPrompt: LONG_DEFAULT_BRANCH_TEMPLATE,
+			userPrompts: persisted<UserPrompt[]>([], PromptPersistedKey.Branch)
+		};
+	}
+
+	get commitPrompts(): Prompts {
+		return {
+			defaultPrompt: LONG_DEFAULT_COMMIT_TEMPLATE,
+			userPrompts: persisted<UserPrompt[]>([], PromptPersistedKey.Commit)
+		};
+	}
+
+	selectedBranchPromptId(projectId: string): Persisted<string | undefined> {
+		return persisted<string | undefined>(undefined, `${PromptPersistedKey.Branch}-${projectId}`);
+	}
+
+	selectedBranchPrompt(projectId: string): Prompt | undefined {
+		const id = get(this.selectedBranchPromptId(projectId));
+
+		if (!id) return;
+
+		return this.findPrompt(get(this.branchPrompts.userPrompts), id);
+	}
+
+	selectedCommitPromptId(projectId: string): Persisted<string | undefined> {
+		return persisted<string | undefined>(undefined, `${PromptPersistedKey.Commit}-${projectId}`);
+	}
+
+	selectedCommitPrompt(projectId: string): Prompt | undefined {
+		const id = get(this.selectedCommitPromptId(projectId));
+
+		if (!id) return;
+
+		return this.findPrompt(get(this.commitPrompts.userPrompts), id);
+	}
+
+	findPrompt(prompts: UserPrompt[], promptId: string) {
+		const prompt = prompts.find((userPrompt) => userPrompt.id == promptId)?.prompt;
+
+		if (!prompt) return;
+		if (this.promptMissingContent(prompt)) return;
+
+		return prompt;
+	}
+
+	promptEquals(prompt1: Prompt, prompt2: Prompt) {
+		if (prompt1.length != prompt2.length) return false;
+
+		for (const indexPromptMessage of prompt1.entries()) {
+			const [index, promptMessage] = indexPromptMessage;
+
+			if (
+				promptMessage.role != prompt2[index].role ||
+				promptMessage.content != prompt2[index].content
+			) {
+				return false;
+			}
+		}
+
+		return true;
+	}
+
+	promptMissingContent(prompt: Prompt) {
+		for (const promptMessage of prompt) {
+			if (!promptMessage.content) return true;
+		}
+
+		return false;
+	}
+
+	createDefaultUserPrompt(type: 'commits' | 'branches'): UserPrompt {
+		return {
+			id: crypto.randomUUID(),
+			name: 'My Prompt',
+			prompt:
+				type == 'branches' ? SHORT_DEFAULT_BRANCH_TEMPLATE : SHORT_DEFAULT_COMMIT_TEMPLATE
+		};
+	}
+}
diff --git a/app/src/lib/ai/prompts.ts b/app/src/lib/ai/prompts.ts
index fd774126c..7c0757ada 100644
--- a/app/src/lib/ai/prompts.ts
+++ b/app/src/lib/ai/prompts.ts
@@ -1,6 +1,6 @@
-import { type PromptMessage, MessageRole } from '$lib/ai/types';
+import { type Prompt, MessageRole } from '$lib/ai/types';
 
-export const SHORT_DEFAULT_COMMIT_TEMPLATE: PromptMessage[] = [
+export const SHORT_DEFAULT_COMMIT_TEMPLATE: Prompt = [
 	{
 		role: MessageRole.User,
 		content: `Please could you write a commit message for my changes.
@@ -16,11 +16,14 @@ Do not start any lines with the hash symbol.
 %{emoji_style}
 
 Here is my git diff:
-%{diff}`
+\`\`\`
+%{diff}
+\`\`\`
+`
 	}
 ];
 
-export const LONG_DEFAULT_COMMIT_TEMPLATE: PromptMessage[] = [
+export const LONG_DEFAULT_COMMIT_TEMPLATE: Prompt = [
 	{
 		role: MessageRole.User,
 		content: `Please could you write a commit message for my changes.
@@ -34,6 +37,7 @@ Do not start any lines with the hash symbol.
 Only respond with the commit message.
 
 Here is my git diff:
+\`\`\`
 diff --git a/src/utils/typing.ts b/src/utils/typing.ts
 index 1cbfaa2..7aeebcf 100644
 --- a/src/utils/typing.ts
@@ -48,7 +52,9 @@ index 1cbfaa2..7aeebcf 100644
 +	check: (value: unknown) => value is T
 +): something is T[] {
 +	return Array.isArray(something) && something.every(check);
-+}`
++}
+\`\`\`
+`
 	},
 	{
 		role: MessageRole.Assistant,
@@ -59,7 +65,7 @@ Added an utility function to check whether a given value is an array of a specif
 	...SHORT_DEFAULT_COMMIT_TEMPLATE
 ];
 
-export const SHORT_DEFAULT_BRANCH_TEMPLATE: PromptMessage[] = [
+export const SHORT_DEFAULT_BRANCH_TEMPLATE: Prompt = [
 	{
 		role: MessageRole.User,
 		content: `Please could you write a branch name for my changes.
@@ -69,11 +75,14 @@ Branch names should contain a maximum of 5 words.
 Only respond with the branch name.
 
 Here is my git diff:
-%{diff}`
+\`\`\`
+%{diff}
+\`\`\`
+`
 	}
 ];
 
-export const LONG_DEFAULT_BRANCH_TEMPLATE: PromptMessage[] = [
+export const LONG_DEFAULT_BRANCH_TEMPLATE: Prompt = [
 	{
 		role: MessageRole.User,
 		content: `Please could you write a branch name for my changes.
@@ -83,6 +92,7 @@ Branch names should contain a maximum of 5 words.
 Only respond with the branch name.
 
 Here is my git diff:
+\`\`\`
 diff --git a/src/utils/typing.ts b/src/utils/typing.ts
 index 1cbfaa2..7aeebcf 100644
 --- a/src/utils/typing.ts
@@ -97,7 +107,9 @@ index 1cbfaa2..7aeebcf 100644
 +	check: (value: unknown) => value is T
 +): something is T[] {
 +	return Array.isArray(something) && something.every(check);
-+}`
++}
+\`\`\`
+`
 	},
 	{
 		role: MessageRole.Assistant,
diff --git a/app/src/lib/ai/service.test.ts b/app/src/lib/ai/service.test.ts
index 51611550f..c1849e554 100644
--- a/app/src/lib/ai/service.test.ts
+++ b/app/src/lib/ai/service.test.ts
@@ -8,7 +8,7 @@ import {
 	ModelKind,
 	OpenAIModelName,
 	type AIClient,
-	type PromptMessage
+	type Prompt
 } from '$lib/ai/types';
 import { HttpClient } from '$lib/backend/httpClient';
 import * as toasts from '$lib/utils/toasts';
@@ -51,7 +51,7 @@ class DummyAIClient implements AIClient {
 	defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
 	constructor(private response = 'lorem ipsum') {}
 
-	async evaluate(_prompt: PromptMessage[]) {
+	async evaluate(_prompt: Prompt) {
 		return this.response;
 	}
 }
diff --git a/app/src/lib/ai/service.ts b/app/src/lib/ai/service.ts
index 2c7e105ec..cba51080a 100644
--- a/app/src/lib/ai/service.ts
+++ b/app/src/lib/ai/service.ts
@@ -11,8 +11,8 @@ import {
 	type AIClient,
 	AnthropicModelName,
 	ModelKind,
-	type PromptMessage,
-	MessageRole
+	MessageRole,
+	type Prompt
 } from '$lib/ai/types';
 import { splitMessage } from '$lib/utils/commitMessage';
 import * as toasts from '$lib/utils/toasts';
@@ -45,13 +45,13 @@ type SummarizeCommitOpts = {
 	hunks: Hunk[];
 	useEmojiStyle?: boolean;
 	useBriefStyle?: boolean;
-	commitTemplate?: PromptMessage[];
+	commitTemplate?: Prompt;
 	userToken?: string;
 };
 
 type SummarizeBranchOpts = {
 	hunks: Hunk[];
-	branchTemplate?: PromptMessage[];
+	branchTemplate?: Prompt;
 	userToken?: string;
 };
diff --git a/app/src/lib/ai/types.ts b/app/src/lib/ai/types.ts
index 5f72bc189..1f3ab5da9 100644
--- a/app/src/lib/ai/types.ts
+++ b/app/src/lib/ai/types.ts
@@ -1,3 +1,5 @@
+import type { Persisted } from '$lib/persisted/persisted';
+
 export enum ModelKind {
 	OpenAI = 'openai',
 	Anthropic = 'anthropic',
@@ -28,9 +30,22 @@ export interface PromptMessage {
 	role: MessageRole;
 }
 
-export interface AIClient {
-	evaluate(prompt: PromptMessage[]): Promise<string>;
+export type Prompt = PromptMessage[];
 
-	defaultBranchTemplate: PromptMessage[];
-	defaultCommitTemplate: PromptMessage[];
+export interface AIClient {
+	evaluate(prompt: Prompt): Promise<string>;
+
+	defaultBranchTemplate: Prompt;
+	defaultCommitTemplate: Prompt;
+}
+
+export interface UserPrompt {
+	id: string;
+	name: string;
+	prompt: Prompt;
+}
+
+export interface Prompts {
+	defaultPrompt: Prompt;
+	userPrompts: Persisted<UserPrompt[]>;
 }
diff --git a/app/src/lib/components/AIPromptEdit/AIPromptEdit.svelte b/app/src/lib/components/AIPromptEdit/AIPromptEdit.svelte
new file mode 100644
index 000000000..2d28ca6a3
--- /dev/null
+++ b/app/src/lib/components/AIPromptEdit/AIPromptEdit.svelte
@@ -0,0 +1,79 @@
+
+
+{#if prompts && $userPrompts}
+
+
+		{promptUse == 'commits' ? 'Commit Message' : 'Branch Name'}
+
+
+
+	{#each $userPrompts as prompt}
+			 deletePrompt(e.detail.prompt)}
+		/>
+	{/each}
+{/if}
+
+
diff --git a/app/src/lib/components/AIPromptEdit/Content.svelte b/app/src/lib/components/AIPromptEdit/Content.svelte
new file mode 100644
index 000000000..711946bd9
--- /dev/null
+++ b/app/src/lib/components/AIPromptEdit/Content.svelte
@@ -0,0 +1,227 @@
+
+
+
+	 e.key === 'Enter' && toggleExpand()}
+	>
+		{#if !isInEditing}
+
+				{promptName}
+
+
+
+		{:else}
+			 e.stopPropagation()} />
+		{/if}
+
+
+	{#if expanded}
+
+		{#each promptMessages as promptMessage, index}
+				 {
+					errorMessages = errorMessages.filter((errorIndex) => errorIndex != index);
+				}}
+				isError={errorMessages.includes(index)}
+			/>
+
+			{#if index % 2 == 0}
+
+			{/if}
+		{/each}
+
+
+		{#if displayMode == 'writable'}
+
+			{#if editing}
+
+
+			{:else}
+
+
+			{/if}
+
+		{/if}
+	{/if}
+
+
diff --git a/app/src/lib/components/AIPromptEdit/DialogBubble.svelte b/app/src/lib/components/AIPromptEdit/DialogBubble.svelte
new file mode 100644
index 000000000..a74a6bce8
--- /dev/null
+++ b/app/src/lib/components/AIPromptEdit/DialogBubble.svelte
@@ -0,0 +1,216 @@
+
+
+
+
+		{#if promptMessage.role == MessageRole.User}
+
+			User
+		{:else}
+
+			Assistant
+		{/if}
+
+	{#if editing}
+