diff --git a/app/src/lib/ai/anthropicClient.ts b/app/src/lib/ai/anthropicClient.ts index 5f2bb3fde..07a6d5cb0 100644 --- a/app/src/lib/ai/anthropicClient.ts +++ b/app/src/lib/ai/anthropicClient.ts @@ -1,24 +1,21 @@ -import { - MessageRole, - type AIClient, - type AnthropicModelName, - type PromptMessage -} from '$lib/ai/types'; +import { SHORT_DEFAULT_COMMIT_TEMPLATE, SHORT_DEFAULT_BRANCH_TEMPLATE } from '$lib/ai/prompts'; import { fetch, Body } from '@tauri-apps/api/http'; +import type { AIClient, AnthropicModelName, PromptMessage } from '$lib/ai/types'; type AnthropicAPIResponse = { content: { text: string }[] }; export class AnthropicAIClient implements AIClient { + defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE; + defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE; + constructor( private apiKey: string, private modelName: AnthropicModelName ) {} - async evaluate(prompt: string) { - const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }]; - + async evaluate(prompt: PromptMessage[]) { const body = Body.json({ - messages, + messages: prompt, max_tokens: 1024, model: this.modelName }); diff --git a/app/src/lib/ai/butlerClient.ts b/app/src/lib/ai/butlerClient.ts index 7fd0f3f5e..a228e83c8 100644 --- a/app/src/lib/ai/butlerClient.ts +++ b/app/src/lib/ai/butlerClient.ts @@ -1,19 +1,21 @@ -import { MessageRole, type ModelKind, type AIClient, type PromptMessage } from '$lib/ai/types'; +import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts'; +import type { AIClient, ModelKind, PromptMessage } from '$lib/ai/types'; import type { HttpClient } from '$lib/backend/httpClient'; export class ButlerAIClient implements AIClient { + defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE; + defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE; + constructor( private cloud: HttpClient, private userToken: string, private modelKind: ModelKind ) {} - async evaluate(prompt: string) { - const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }]; - + async evaluate(prompt: PromptMessage[]) { const response = await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', { body: { - messages, + messages: prompt, max_tokens: 400, model_kind: this.modelKind }, diff --git a/app/src/lib/ai/ollamaClient.ts b/app/src/lib/ai/ollamaClient.ts index cbc4f8ad2..25251d6bf 100644 --- a/app/src/lib/ai/ollamaClient.ts +++ b/app/src/lib/ai/ollamaClient.ts @@ -1,59 +1,10 @@ +import { LONG_DEFAULT_BRANCH_TEMPLATE, LONG_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts'; import { MessageRole, type PromptMessage, type AIClient } from '$lib/ai/types'; import { isNonEmptyObject } from '$lib/utils/typeguards'; export const DEFAULT_OLLAMA_ENDPOINT = 'http://127.0.0.1:11434'; export const DEFAULT_OLLAMA_MODEL_NAME = 'llama3'; -const PROMT_EXAMPLE_COMMIT_MESSAGE_GENERATION = `Please could you write a commit message for my changes. -Explain what were the changes and why the changes were done. -Focus the most important changes. -Use the present tense. -Use a semantic commit prefix. -Hard wrap lines at 72 characters. -Ensure the title is only 50 characters. -Do not start any lines with the hash symbol. -Only respond with the commit message. - -Here is my git diff: -diff --git a/src/utils/typing.ts b/src/utils/typing.ts -index 1cbfaa2..7aeebcf 100644 ---- a/src/utils/typing.ts -+++ b/src/utils/typing.ts -@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject - (Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0) - ); - } -+ -+export function isArrayOf( -+ something: unknown, -+ check: (value: unknown) => value is T -+): something is T[] { -+ return Array.isArray(something) && something.every(check); -+}`; - -const PROMPT_EXAMPLE_BRANCH_NAME_GENERATION = `Please could you write a branch name for my changes. -A branch name represent a brief description of the changes in the diff (branch). -Branch names should contain no whitespace and instead use dashes to separate words. -Branch names should contain a maximum of 5 words. -Only respond with the branch name. - -Here is my git diff: -diff --git a/src/utils/typing.ts b/src/utils/typing.ts -index 1cbfaa2..7aeebcf 100644 ---- a/src/utils/typing.ts -+++ b/src/utils/typing.ts -@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject - (Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0) - ); - } -+ -+export function isArrayOf( -+ something: unknown, -+ check: (value: unknown) => value is T -+): something is T[] { -+ return Array.isArray(something) && something.every(check); -+}`; - enum OllamaAPEndpoint { Generate = 'api/generate', Chat = 'api/chat', @@ -91,7 +42,7 @@ interface OllamaChatMessageFormat { result: string; } -const OllamaChatMessageFormatSchema = { +const OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA = { type: 'object', properties: { result: { type: 'string' } @@ -121,32 +72,16 @@ function isOllamaChatResponse(response: unknown): response is OllamaChatResponse } export class OllamaClient implements AIClient { + defaultCommitTemplate = LONG_DEFAULT_COMMIT_TEMPLATE; + defaultBranchTemplate = LONG_DEFAULT_BRANCH_TEMPLATE; + constructor( private endpoint: string, private modelName: string ) {} - async branchName(prompt: string): Promise { - const messages: PromptMessage[] = [ - { - role: MessageRole.System, - content: `You are an expert in software development. Answer the given user prompts following the specified instructions. -Return your response in JSON and only use the following JSON schema: -${JSON.stringify(OllamaChatMessageFormatSchema, null, 2)}` - }, - { - role: MessageRole.User, - content: PROMPT_EXAMPLE_BRANCH_NAME_GENERATION - }, - { - role: MessageRole.Assistant, - content: JSON.stringify({ - result: `utils-typing-is-array-of-type` - }) - }, - { role: MessageRole.User, content: prompt } - ]; - + async evaluate(prompt: PromptMessage[]) { + const messages = this.formatPrompt(prompt); const response = await this.chat(messages); const rawResponse = JSON.parse(response.message.content); if (!isOllamaChatMessageFormat(rawResponse)) { @@ -156,36 +91,31 @@ ${JSON.stringify(OllamaChatMessageFormatSchema, null, 2)}` return rawResponse.result; } - async evaluate(prompt: string) { - const messages: PromptMessage[] = [ + /** + * Appends a system message which instructs the model to respond using a particular JSON schema + * Modifies the prompt's Assistant messages to make use of the correct schema + */ + private formatPrompt(prompt: PromptMessage[]) { + const withFormattedResponses = prompt.map((promptMessage) => { + if (promptMessage.role == MessageRole.Assistant) { + return { + role: MessageRole.Assistant, + content: JSON.stringify({ result: promptMessage.content }) + }; + } else { + return promptMessage; + } + }); + + return [ { role: MessageRole.System, content: `You are an expert in software development. Answer the given user prompts following the specified instructions. Return your response in JSON and only use the following JSON schema: -${JSON.stringify(OllamaChatMessageFormatSchema, null, 2)}` +${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}` }, - { - role: MessageRole.User, - content: PROMT_EXAMPLE_COMMIT_MESSAGE_GENERATION - }, - { - role: MessageRole.Assistant, - content: JSON.stringify({ - result: `Typing utilities: Check for array of type - -Added an utility function to check whether a given value is an array of a specific type.` - }) - }, - { role: MessageRole.User, content: prompt } + ...withFormattedResponses ]; - - const response = await this.chat(messages); - const rawResponse = JSON.parse(response.message.content); - if (!isOllamaChatMessageFormat(rawResponse)) { - throw new Error('Invalid response: ' + response.message.content); - } - - return rawResponse.result; } /** diff --git a/app/src/lib/ai/openAIClient.ts b/app/src/lib/ai/openAIClient.ts index 706a347ef..62ded36b8 100644 --- a/app/src/lib/ai/openAIClient.ts +++ b/app/src/lib/ai/openAIClient.ts @@ -1,22 +1,19 @@ -import { - MessageRole, - type OpenAIModelName, - type PromptMessage, - type AIClient -} from '$lib/ai/types'; +import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts'; +import type { OpenAIModelName, PromptMessage, AIClient } from '$lib/ai/types'; import type OpenAI from 'openai'; export class OpenAIClient implements AIClient { + defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE; + defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE; + constructor( private modelName: OpenAIModelName, private openAI: OpenAI ) {} - async evaluate(prompt: string) { - const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }]; - + async evaluate(prompt: PromptMessage[]) { const response = await this.openAI.chat.completions.create({ - messages, + messages: prompt, model: this.modelName, max_tokens: 400 }); diff --git a/app/src/lib/ai/prompts.ts b/app/src/lib/ai/prompts.ts new file mode 100644 index 000000000..fd774126c --- /dev/null +++ b/app/src/lib/ai/prompts.ts @@ -0,0 +1,107 @@ +import { type PromptMessage, MessageRole } from '$lib/ai/types'; + +export const SHORT_DEFAULT_COMMIT_TEMPLATE: PromptMessage[] = [ + { + role: MessageRole.User, + content: `Please could you write a commit message for my changes. +Only respond with the commit message. Don't give any notes. +Explain what were the changes and why the changes were done. +Focus the most important changes. +Use the present tense. +Use a semantic commit prefix. +Hard wrap lines at 72 characters. +Ensure the title is only 50 characters. +Do not start any lines with the hash symbol. +%{brief_style} +%{emoji_style} + +Here is my git diff: +%{diff}` + } +]; + +export const LONG_DEFAULT_COMMIT_TEMPLATE: PromptMessage[] = [ + { + role: MessageRole.User, + content: `Please could you write a commit message for my changes. +Explain what were the changes and why the changes were done. +Focus the most important changes. +Use the present tense. +Use a semantic commit prefix. +Hard wrap lines at 72 characters. +Ensure the title is only 50 characters. +Do not start any lines with the hash symbol. +Only respond with the commit message. + +Here is my git diff: +diff --git a/src/utils/typing.ts b/src/utils/typing.ts +index 1cbfaa2..7aeebcf 100644 +--- a/src/utils/typing.ts ++++ b/src/utils/typing.ts +@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject + (Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0) + ); + } ++ ++export function isArrayOf( ++ something: unknown, ++ check: (value: unknown) => value is T ++): something is T[] { ++ return Array.isArray(something) && something.every(check); ++}` + }, + { + role: MessageRole.Assistant, + content: `Typing utilities: Check for array of type + +Added an utility function to check whether a given value is an array of a specific type.` + }, + ...SHORT_DEFAULT_COMMIT_TEMPLATE +]; + +export const SHORT_DEFAULT_BRANCH_TEMPLATE: PromptMessage[] = [ + { + role: MessageRole.User, + content: `Please could you write a branch name for my changes. +A branch name represent a brief description of the changes in the diff (branch). +Branch names should contain no whitespace and instead use dashes to separate words. +Branch names should contain a maximum of 5 words. +Only respond with the branch name. + +Here is my git diff: +%{diff}` + } +]; + +export const LONG_DEFAULT_BRANCH_TEMPLATE: PromptMessage[] = [ + { + role: MessageRole.User, + content: `Please could you write a branch name for my changes. +A branch name represent a brief description of the changes in the diff (branch). +Branch names should contain no whitespace and instead use dashes to separate words. +Branch names should contain a maximum of 5 words. +Only respond with the branch name. + +Here is my git diff: +diff --git a/src/utils/typing.ts b/src/utils/typing.ts +index 1cbfaa2..7aeebcf 100644 +--- a/src/utils/typing.ts ++++ b/src/utils/typing.ts +@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject + (Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0) + ); + } ++ ++export function isArrayOf( ++ something: unknown, ++ check: (value: unknown) => value is T ++): something is T[] { ++ return Array.isArray(something) && something.every(check); ++}` + }, + { + role: MessageRole.Assistant, + content: `utils-typing-is-array-of-type` + }, + ...SHORT_DEFAULT_BRANCH_TEMPLATE +]; diff --git a/app/src/lib/ai/service.test.ts b/app/src/lib/ai/service.test.ts index 88c7e5fab..51611550f 100644 --- a/app/src/lib/ai/service.test.ts +++ b/app/src/lib/ai/service.test.ts @@ -1,8 +1,15 @@ import { AnthropicAIClient } from '$lib/ai/anthropicClient'; import { ButlerAIClient } from '$lib/ai/butlerClient'; import { OpenAIClient } from '$lib/ai/openAIClient'; +import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts'; import { AIService, GitAIConfigKey, KeyOption, buildDiff } from '$lib/ai/service'; -import { AnthropicModelName, ModelKind, OpenAIModelName, type AIClient } from '$lib/ai/types'; +import { + AnthropicModelName, + ModelKind, + OpenAIModelName, + type AIClient, + type PromptMessage +} from '$lib/ai/types'; import { HttpClient } from '$lib/backend/httpClient'; import * as toasts from '$lib/utils/toasts'; import { Hunk } from '$lib/vbranches/types'; @@ -40,9 +47,11 @@ const fetchMock = vi.fn(); const cloud = new HttpClient(fetchMock); class DummyAIClient implements AIClient { + defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE; + defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE; constructor(private response = 'lorem ipsum') {} - async evaluate(_prompt: string) { + async evaluate(_prompt: PromptMessage[]) { return this.response; } } diff --git a/app/src/lib/ai/service.ts b/app/src/lib/ai/service.ts index 4bcdf2cbc..2c7e105ec 100644 --- a/app/src/lib/ai/service.ts +++ b/app/src/lib/ai/service.ts @@ -6,7 +6,14 @@ import { OllamaClient } from '$lib/ai/ollamaClient'; import { OpenAIClient } from '$lib/ai/openAIClient'; -import { OpenAIModelName, type AIClient, AnthropicModelName, ModelKind } from '$lib/ai/types'; +import { + OpenAIModelName, + type AIClient, + AnthropicModelName, + ModelKind, + type PromptMessage, + MessageRole +} from '$lib/ai/types'; import { splitMessage } from '$lib/utils/commitMessage'; import * as toasts from '$lib/utils/toasts'; import OpenAI from 'openai'; @@ -16,34 +23,6 @@ import type { Hunk } from '$lib/vbranches/types'; const maxDiffLengthLimitForAPI = 5000; -const defaultCommitTemplate = ` -Please could you write a commit message for my changes. -Only respond with the commit message. Don't give any notes. -Explain what were the changes and why the changes were done. -Focus the most important changes. -Use the present tense. -Use a semantic commit prefix. -Hard wrap lines at 72 characters. -Ensure the title is only 50 characters. -Do not start any lines with the hash symbol. -%{brief_style} -%{emoji_style} - -Here is my git diff: -%{diff} -`; - -const defaultBranchTemplate = ` -Please could you write a branch name for my changes. -A branch name represent a brief description of the changes in the diff (branch). -Branch names should contain no whitespace and instead use dashes to separate words. -Branch names should contain a maximum of 5 words. -Only respond with the branch name. - -Here is my git diff: -%{diff} -`; - export enum KeyOption { BringYourOwn = 'bringYourOwn', ButlerAPI = 'butlerAPI' @@ -66,13 +45,13 @@ type SummarizeCommitOpts = { hunks: Hunk[]; useEmojiStyle?: boolean; useBriefStyle?: boolean; - commitTemplate?: string; + commitTemplate?: PromptMessage[]; userToken?: string; }; type SummarizeBranchOpts = { hunks: Hunk[]; - branchTemplate?: string; + branchTemplate?: PromptMessage[]; userToken?: string; }; @@ -261,24 +240,37 @@ export class AIService { hunks, useEmojiStyle = false, useBriefStyle = false, - commitTemplate = defaultCommitTemplate, + commitTemplate, userToken }: SummarizeCommitOpts) { const aiClient = await this.buildClient(userToken); if (!aiClient) return; const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI(); - let prompt = commitTemplate.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit)); + const defaultedCommitTemplate = commitTemplate || aiClient.defaultCommitTemplate; - const briefPart = useBriefStyle - ? 'The commit message must be only one sentence and as short as possible.' - : ''; - prompt = prompt.replaceAll('%{brief_style}', briefPart); + const prompt = defaultedCommitTemplate.map((promptMessage) => { + if (promptMessage.role != MessageRole.User) { + return promptMessage; + } - const emojiPart = useEmojiStyle - ? 'Make use of GitMoji in the title prefix.' - : "Don't use any emoji."; - prompt = prompt.replaceAll('%{emoji_style}', emojiPart); + let content = promptMessage.content.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit)); + + const briefPart = useBriefStyle + ? 'The commit message must be only one sentence and as short as possible.' + : ''; + content = content.replaceAll('%{brief_style}', briefPart); + + const emojiPart = useEmojiStyle + ? 'Make use of GitMoji in the title prefix.' + : "Don't use any emoji."; + content = content.replaceAll('%{emoji_style}', emojiPart); + + return { + role: MessageRole.User, + content + }; + }); let message = await aiClient.evaluate(prompt); @@ -290,18 +282,24 @@ export class AIService { return description ? `${title}\n\n${description}` : title; } - async summarizeBranch({ - hunks, - branchTemplate = defaultBranchTemplate, - userToken = undefined - }: SummarizeBranchOpts) { + async summarizeBranch({ hunks, branchTemplate, userToken = undefined }: SummarizeBranchOpts) { const aiClient = await this.buildClient(userToken); if (!aiClient) return; const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI(); - const prompt = branchTemplate.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit)); - const branchNamePromise = aiClient.evaluate(prompt); - const message = await branchNamePromise; + const defaultedBranchTemplate = branchTemplate || aiClient.defaultBranchTemplate; + const prompt = defaultedBranchTemplate.map((promptMessage) => { + if (promptMessage.role != MessageRole.User) { + return promptMessage; + } + + return { + role: MessageRole.User, + content: promptMessage.content.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit)) + }; + }); + + const message = await aiClient.evaluate(prompt); return message.replaceAll(' ', '-').replaceAll('\n', '-'); } } diff --git a/app/src/lib/ai/types.ts b/app/src/lib/ai/types.ts index 4999fe5d7..55d450beb 100644 --- a/app/src/lib/ai/types.ts +++ b/app/src/lib/ai/types.ts @@ -28,5 +28,8 @@ export interface PromptMessage { } export interface AIClient { - evaluate(prompt: string): Promise; + evaluate(prompt: PromptMessage[]): Promise; + + defaultBranchTemplate: PromptMessage[]; + defaultCommitTemplate: PromptMessage[]; }