diff --git a/app/src/lib/ai/anthropicClient.ts b/app/src/lib/ai/anthropicClient.ts
index 5f2bb3fde..07a6d5cb0 100644
--- a/app/src/lib/ai/anthropicClient.ts
+++ b/app/src/lib/ai/anthropicClient.ts
@@ -1,24 +1,21 @@
-import {
-	MessageRole,
-	type AIClient,
-	type AnthropicModelName,
-	type PromptMessage
-} from '$lib/ai/types';
+import { SHORT_DEFAULT_COMMIT_TEMPLATE, SHORT_DEFAULT_BRANCH_TEMPLATE } from '$lib/ai/prompts';
 import { fetch, Body } from '@tauri-apps/api/http';
+import type { AIClient, AnthropicModelName, PromptMessage } from '$lib/ai/types';
 
 type AnthropicAPIResponse = { content: { text: string }[] };
 
 export class AnthropicAIClient implements AIClient {
+	defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
+	defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
+
 	constructor(
 		private apiKey: string,
 		private modelName: AnthropicModelName
 	) {}
 
-	async evaluate(prompt: string) {
-		const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }];
-
+	async evaluate(prompt: PromptMessage[]) {
 		const body = Body.json({
-			messages,
+			messages: prompt,
 			max_tokens: 1024,
 			model: this.modelName
 		});
diff --git a/app/src/lib/ai/butlerClient.ts b/app/src/lib/ai/butlerClient.ts
index 7fd0f3f5e..a228e83c8 100644
--- a/app/src/lib/ai/butlerClient.ts
+++ b/app/src/lib/ai/butlerClient.ts
@@ -1,19 +1,21 @@
-import { MessageRole, type ModelKind, type AIClient, type PromptMessage } from '$lib/ai/types';
+import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
+import type { AIClient, ModelKind, PromptMessage } from '$lib/ai/types';
 import type { HttpClient } from '$lib/backend/httpClient';
 
 export class ButlerAIClient implements AIClient {
+	defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
+	defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
+
 	constructor(
 		private cloud: HttpClient,
 		private userToken: string,
 		private modelKind: ModelKind
 	) {}
 
-	async evaluate(prompt: string) {
-		const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }];
-
+	async evaluate(prompt: PromptMessage[]) {
 		const response = await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', {
 			body: {
-				messages,
+				messages: prompt,
 				max_tokens: 400,
 				model_kind: this.modelKind
 			},
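Every client now accepts an ordered message array instead of a bare string, so multi-turn context (including few-shot examples) flows through unchanged. A minimal sketch of the new contract — the standalone `client` declaration is illustrative; in the app the instance comes from AIService:

import { MessageRole, type AIClient, type PromptMessage } from '$lib/ai/types';

declare const client: AIClient; // any implementation: Anthropic, Butler, OpenAI, Ollama

// A single-turn prompt, equivalent to what the old string-based evaluate() built internally.
const prompt: PromptMessage[] = [
	{ role: MessageRole.User, content: 'Write a commit message for my changes.' }
];

const message: string = await client.evaluate(prompt);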
diff --git a/app/src/lib/ai/ollamaClient.ts b/app/src/lib/ai/ollamaClient.ts
new file mode 100644
index 000000000..25251d6bf
--- /dev/null
+++ b/app/src/lib/ai/ollamaClient.ts
@@ -0,0 +1,165 @@
+import { LONG_DEFAULT_BRANCH_TEMPLATE, LONG_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
+import { MessageRole, type PromptMessage, type AIClient } from '$lib/ai/types';
+import { isNonEmptyObject } from '$lib/utils/typeguards';
+
+export const DEFAULT_OLLAMA_ENDPOINT = 'http://127.0.0.1:11434';
+export const DEFAULT_OLLAMA_MODEL_NAME = 'llama3';
+
+enum OllamaAPIEndpoint {
+	Generate = 'api/generate',
+	Chat = 'api/chat',
+	Embed = 'api/embeddings'
+}
+
+interface OllamaRequestOptions {
+	/**
+	 * The temperature of the model.
+	 * Increasing the temperature will make the model answer more creatively. (Default: 0.8)
+	 */
+	temperature: number;
+}
+
+interface OllamaChatRequest {
+	model: string;
+	messages: PromptMessage[];
+	stream: boolean;
+	format?: 'json';
+	options?: OllamaRequestOptions;
+}
+
+interface BaseOllamaResponse {
+	created_at: string;
+	done: boolean;
+	model: string;
+}
+
+interface OllamaChatResponse extends BaseOllamaResponse {
+	message: PromptMessage;
+	done: true;
+}
+
+interface OllamaChatMessageFormat {
+	result: string;
+}
+
+const OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA = {
+	type: 'object',
+	properties: {
+		result: { type: 'string' }
+	},
+	required: ['result'],
+	additionalProperties: false
+};
+
+function isOllamaChatMessageFormat(message: unknown): message is OllamaChatMessageFormat {
+	if (!isNonEmptyObject(message)) {
+		return false;
+	}
+
+	return typeof message.result === 'string';
+}
+
+function isOllamaChatResponse(response: unknown): response is OllamaChatResponse {
+	if (!isNonEmptyObject(response)) {
+		return false;
+	}
+
+	return (
+		isNonEmptyObject(response.message) &&
+		typeof response.message.role == 'string' &&
+		typeof response.message.content == 'string'
+	);
+}
+
+export class OllamaClient implements AIClient {
+	defaultCommitTemplate = LONG_DEFAULT_COMMIT_TEMPLATE;
+	defaultBranchTemplate = LONG_DEFAULT_BRANCH_TEMPLATE;
+
+	constructor(
+		private endpoint: string,
+		private modelName: string
+	) {}
+
+	async evaluate(prompt: PromptMessage[]) {
+		const messages = this.formatPrompt(prompt);
+		const response = await this.chat(messages);
+		const rawResponse = JSON.parse(response.message.content);
+		if (!isOllamaChatMessageFormat(rawResponse)) {
+			throw new Error('Invalid response: ' + response.message.content);
+		}
+
+		return rawResponse.result;
+	}
+
+	/**
+	 * Appends a system message which instructs the model to respond using a particular JSON schema.
+	 * Modifies the prompt's Assistant messages to make use of the correct schema.
+	 */
+	private formatPrompt(prompt: PromptMessage[]) {
+		const withFormattedResponses = prompt.map((promptMessage) => {
+			if (promptMessage.role == MessageRole.Assistant) {
+				return {
+					role: MessageRole.Assistant,
+					content: JSON.stringify({ result: promptMessage.content })
+				};
+			} else {
+				return promptMessage;
+			}
+		});
+
+		return [
+			{
+				role: MessageRole.System,
+				content: `You are an expert in software development. Answer the given user prompts following the specified instructions.
+Return your response in JSON and only use the following JSON schema:
+${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}`
+			},
+			...withFormattedResponses
+		];
+	}
+
+	/**
+	 * Fetches the chat using the specified request.
+	 * @param request - The OllamaChatRequest object containing the request details.
+	 * @returns A Promise that resolves to the Response object.
+	 */
+	private async fetchChat(request: OllamaChatRequest): Promise<Response> {
+		const url = new URL(OllamaAPIEndpoint.Chat, this.endpoint);
+		const result = await fetch(url, {
+			method: 'POST',
+			headers: {
+				'Content-Type': 'application/json'
+			},
+			body: JSON.stringify(request)
+		});
+		return result;
+	}
+
+	/**
+	 * Sends chat messages to the model and returns its response.
+	 *
+	 * @param messages - An array of PromptMessage objects representing the chat messages.
+	 * @param options - Optional OllamaRequestOptions object for specifying additional options.
+	 * @throws Error if the response is invalid.
+	 * @returns A Promise that resolves to the OllamaChatResponse object returned by the model.
+	 */
+	private async chat(
+		messages: PromptMessage[],
+		options?: OllamaRequestOptions
+	): Promise<OllamaChatResponse> {
+		const result = await this.fetchChat({
+			model: this.modelName,
+			stream: false,
+			messages,
+			options,
+			format: 'json'
+		});
+
+		const json = await result.json();
+		if (!isOllamaChatResponse(json)) {
+			throw new Error('Invalid response\n' + JSON.stringify(json));
+		}
+
+		return json;
+	}
+}
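OllamaClient asks Ollama for JSON output (format: 'json') and additionally pins the shape through the prompt itself: a system message carries the schema, assistant turns are rewritten into conforming examples, and evaluate() unwraps the result field. A usage sketch, assuming a local Ollama server is reachable at the default endpoint and has the default model pulled:

import {
	DEFAULT_OLLAMA_ENDPOINT,
	DEFAULT_OLLAMA_MODEL_NAME,
	OllamaClient
} from '$lib/ai/ollamaClient';
import { MessageRole } from '$lib/ai/types';

const client = new OllamaClient(DEFAULT_OLLAMA_ENDPOINT, DEFAULT_OLLAMA_MODEL_NAME);

// The client injects the system message and schema itself; callers only supply turns.
// A conforming model answers with {"result": "..."} and evaluate() returns the string.
const answer = await client.evaluate([
	{ role: MessageRole.User, content: 'Suggest a branch name for a typo fix.' }
]);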
diff --git a/app/src/lib/ai/openAIClient.ts b/app/src/lib/ai/openAIClient.ts
index 6f7f6bd10..62ded36b8 100644
--- a/app/src/lib/ai/openAIClient.ts
+++ b/app/src/lib/ai/openAIClient.ts
@@ -1,24 +1,19 @@
-import {
-	MessageRole,
-	type OpenAIModelName,
-	type PromptMessage,
-	type AIClient
-} from '$lib/ai/types';
+import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
+import type { OpenAIModelName, PromptMessage, AIClient } from '$lib/ai/types';
 import type OpenAI from 'openai';
 
 export class OpenAIClient implements AIClient {
+	defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
+	defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
+
 	constructor(
 		private modelName: OpenAIModelName,
 		private openAI: OpenAI
 	) {}
 
-	async evaluate(prompt: string) {
-		const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }];
-
+	async evaluate(prompt: PromptMessage[]) {
 		const response = await this.openAI.chat.completions.create({
-			// @ts-expect-error There is a type mismatch where it seems to want a "name" paramater
-			// that isn't required https://github.com/openai/openai-openapi/issues/118#issuecomment-1847667988
-			messages,
+			messages: prompt,
 			model: this.modelName,
 			max_tokens: 400
 		});
diff --git a/app/src/lib/ai/prompts.ts b/app/src/lib/ai/prompts.ts
new file mode 100644
index 000000000..fd774126c
--- /dev/null
+++ b/app/src/lib/ai/prompts.ts
@@ -0,0 +1,107 @@
+import { type PromptMessage, MessageRole } from '$lib/ai/types';
+
+export const SHORT_DEFAULT_COMMIT_TEMPLATE: PromptMessage[] = [
+	{
+		role: MessageRole.User,
+		content: `Please could you write a commit message for my changes.
+Only respond with the commit message. Don't give any notes.
+Explain what the changes were and why they were made.
+Focus on the most important changes.
+Use the present tense.
+Use a semantic commit prefix.
+Hard wrap lines at 72 characters.
+Ensure the title is only 50 characters.
+Do not start any lines with the hash symbol.
+%{brief_style}
+%{emoji_style}
+
+Here is my git diff:
+%{diff}`
+	}
+];
+
+export const LONG_DEFAULT_COMMIT_TEMPLATE: PromptMessage[] = [
+	{
+		role: MessageRole.User,
+		content: `Please could you write a commit message for my changes.
+Explain what the changes were and why they were made.
+Focus on the most important changes.
+Use the present tense.
+Use a semantic commit prefix.
+Hard wrap lines at 72 characters.
+Ensure the title is only 50 characters.
+Do not start any lines with the hash symbol.
+Only respond with the commit message.
+
+Here is my git diff:
+diff --git a/src/utils/typing.ts b/src/utils/typing.ts
+index 1cbfaa2..7aeebcf 100644
+--- a/src/utils/typing.ts
++++ b/src/utils/typing.ts
+@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject
+     (Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0)
+   );
+ }
++
++export function isArrayOf<T>(
++  something: unknown,
++  check: (value: unknown) => value is T
++): something is T[] {
++  return Array.isArray(something) && something.every(check);
++}`
+	},
+	{
+		role: MessageRole.Assistant,
+		content: `Typing utilities: Check for array of type
+
+Added a utility function to check whether a given value is an array of a specific type.`
+	},
+	...SHORT_DEFAULT_COMMIT_TEMPLATE
+];
+
+export const SHORT_DEFAULT_BRANCH_TEMPLATE: PromptMessage[] = [
+	{
+		role: MessageRole.User,
+		content: `Please could you write a branch name for my changes.
+A branch name represents a brief description of the changes in the diff (branch).
+Branch names should contain no whitespace and instead use dashes to separate words.
+Branch names should contain a maximum of 5 words.
+Only respond with the branch name.
+
+Here is my git diff:
+%{diff}`
+	}
+];
+
+export const LONG_DEFAULT_BRANCH_TEMPLATE: PromptMessage[] = [
+	{
+		role: MessageRole.User,
+		content: `Please could you write a branch name for my changes.
+A branch name represents a brief description of the changes in the diff (branch).
+Branch names should contain no whitespace and instead use dashes to separate words.
+Branch names should contain a maximum of 5 words.
+Only respond with the branch name.
+
+Here is my git diff:
+diff --git a/src/utils/typing.ts b/src/utils/typing.ts
+index 1cbfaa2..7aeebcf 100644
+--- a/src/utils/typing.ts
++++ b/src/utils/typing.ts
+@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject
+     (Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0)
+   );
+ }
++
++export function isArrayOf<T>(
++  something: unknown,
++  check: (value: unknown) => value is T
++): something is T[] {
++  return Array.isArray(something) && something.every(check);
++}`
+	},
+	{
+		role: MessageRole.Assistant,
+		content: `utils-typing-is-array-of-type`
+	},
+	...SHORT_DEFAULT_BRANCH_TEMPLATE
+];
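The long templates are few-shot prompts: a canned user diff, the assistant reply the model should imitate, then the live request spread in from the short template. Custom templates can follow the same shape — a sketch, with an invented example turn:

import { SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { MessageRole, type PromptMessage } from '$lib/ai/types';

const CUSTOM_COMMIT_TEMPLATE: PromptMessage[] = [
	// One worked example: a user turn (no %{diff} placeholder needed here)...
	{ role: MessageRole.User, content: 'Here is my git diff:\n<a small example diff>' },
	// ...and the answer the model should imitate.
	{ role: MessageRole.Assistant, content: 'fix: correct typo in error message' },
	// End with the live request; its %{diff} placeholder is filled in by AIService.
	...SHORT_DEFAULT_COMMIT_TEMPLATE
];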
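DummyAIClient keeps these tests offline: it satisfies the widened interface by reusing the short templates and echoing a canned string regardless of the prompt. A sketch of a test this file could add:

import { expect, test } from 'vitest';

test('DummyAIClient ignores the prompt and echoes its response', async () => {
	const client = new DummyAIClient('lorem ipsum');
	expect(await client.evaluate([])).toBe('lorem ipsum');
});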
diff --git a/app/src/lib/ai/service.test.ts b/app/src/lib/ai/service.test.ts
index 88c7e5fab..51611550f 100644
--- a/app/src/lib/ai/service.test.ts
+++ b/app/src/lib/ai/service.test.ts
@@ -1,8 +1,15 @@
 import { AnthropicAIClient } from '$lib/ai/anthropicClient';
 import { ButlerAIClient } from '$lib/ai/butlerClient';
 import { OpenAIClient } from '$lib/ai/openAIClient';
+import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
 import { AIService, GitAIConfigKey, KeyOption, buildDiff } from '$lib/ai/service';
-import { AnthropicModelName, ModelKind, OpenAIModelName, type AIClient } from '$lib/ai/types';
+import {
+	AnthropicModelName,
+	ModelKind,
+	OpenAIModelName,
+	type AIClient,
+	type PromptMessage
+} from '$lib/ai/types';
 import { HttpClient } from '$lib/backend/httpClient';
 import * as toasts from '$lib/utils/toasts';
 import { Hunk } from '$lib/vbranches/types';
@@ -40,9 +47,11 @@ const fetchMock = vi.fn();
 const cloud = new HttpClient(fetchMock);
 
 class DummyAIClient implements AIClient {
+	defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
+	defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
 	constructor(private response = 'lorem ipsum') {}
 
-	async evaluate(_prompt: string) {
+	async evaluate(_prompt: PromptMessage[]) {
 		return this.response;
 	}
 }
diff --git a/app/src/lib/ai/service.ts b/app/src/lib/ai/service.ts
index 05ff627e6..2c7e105ec 100644
--- a/app/src/lib/ai/service.ts
+++ b/app/src/lib/ai/service.ts
@@ -1,7 +1,19 @@
 import { AnthropicAIClient } from '$lib/ai/anthropicClient';
 import { ButlerAIClient } from '$lib/ai/butlerClient';
+import {
+	DEFAULT_OLLAMA_ENDPOINT,
+	DEFAULT_OLLAMA_MODEL_NAME,
+	OllamaClient
+} from '$lib/ai/ollamaClient';
 import { OpenAIClient } from '$lib/ai/openAIClient';
-import { OpenAIModelName, type AIClient, AnthropicModelName } from '$lib/ai/types';
+import {
+	OpenAIModelName,
+	type AIClient,
+	AnthropicModelName,
+	ModelKind,
+	type PromptMessage,
+	MessageRole
+} from '$lib/ai/types';
 import { splitMessage } from '$lib/utils/commitMessage';
 import * as toasts from '$lib/utils/toasts';
 import OpenAI from 'openai';
@@ -11,39 +23,6 @@ import type { Hunk } from '$lib/vbranches/types';
 
 const maxDiffLengthLimitForAPI = 5000;
 
-const defaultCommitTemplate = `
-Please could you write a commit message for my changes.
-Explain what were the changes and why the changes were done.
-Focus the most important changes.
-Use the present tense.
-Use a semantic commit prefix.
-Hard wrap lines at 72 characters.
-Ensure the title is only 50 characters.
-Do not start any lines with the hash symbol.
-Only respond with the commit message.
-%{brief_style}
-%{emoji_style}
-
-Here is my git diff:
-%{diff}
-`;
-
-const defaultBranchTemplate = `
-Please could you write a branch name for my changes.
-A branch name represent a brief description of the changes in the diff (branch).
-Branch names should contain no whitespace and instead use dashes to separate words.
-Branch names should contain a maximum of 5 words.
-Only respond with the branch name.
-
-Here is my git diff:
-%{diff}
-`;
-
-export enum ModelKind {
-	OpenAI = 'openai',
-	Anthropic = 'anthropic'
-}
-
 export enum KeyOption {
 	BringYourOwn = 'bringYourOwn',
 	ButlerAPI = 'butlerAPI'
@@ -57,20 +36,22 @@ export enum GitAIConfigKey {
 	AnthropicKeyOption = 'gitbutler.aiAnthropicKeyOption',
 	AnthropicModelName = 'gitbutler.aiAnthropicModelName',
 	AnthropicKey = 'gitbutler.aiAnthropicKey',
-	DiffLengthLimit = 'gitbutler.diffLengthLimit'
+	DiffLengthLimit = 'gitbutler.diffLengthLimit',
+	OllamaEndpoint = 'gitbutler.aiOllamaEndpoint',
+	OllamaModelName = 'gitbutler.aiOllamaModelName'
 }
 
 type SummarizeCommitOpts = {
 	hunks: Hunk[];
 	useEmojiStyle?: boolean;
 	useBriefStyle?: boolean;
-	commitTemplate?: string;
+	commitTemplate?: PromptMessage[];
 	userToken?: string;
 };
 
 type SummarizeBranchOpts = {
 	hunks: Hunk[];
-	branchTemplate?: string;
+	branchTemplate?: PromptMessage[];
 	userToken?: string;
 };
 
@@ -84,7 +65,7 @@ export function buildDiff(hunks: Hunk[], limit: number) {
 function shuffle<T>(items: T[]): T[] {
 	return items
 		.map((item) => ({ item, value: Math.random() }))
-		.sort()
+		.sort(({ value: a }, { value: b }) => a - b)
 		.map((item) => item.item);
 }
 
@@ -159,6 +140,20 @@ export class AIService {
 		}
 	}
 
+	async getOllamaEndpoint() {
+		return await this.gitConfig.getWithDefault(
+			GitAIConfigKey.OllamaEndpoint,
+			DEFAULT_OLLAMA_ENDPOINT
+		);
+	}
+
+	async getOllamaModelName() {
+		return await this.gitConfig.getWithDefault(
+			GitAIConfigKey.OllamaModelName,
+			DEFAULT_OLLAMA_MODEL_NAME
+		);
+	}
+
 	async usingGitButlerAPI() {
 		const modelKind = await this.getModelKind();
 		const openAIKeyOption = await this.getOpenAIKeyOption();
@@ -176,13 +171,19 @@
 		const modelKind = await this.getModelKind();
 		const openAIKey = await this.getOpenAIKey();
 		const anthropicKey = await this.getAnthropicKey();
+		const ollamaEndpoint = await this.getOllamaEndpoint();
+		const ollamaModelName = await this.getOllamaModelName();
 
 		if (await this.usingGitButlerAPI()) return !!userToken;
 
 		const openAIActiveAndKeyProvided = modelKind == ModelKind.OpenAI && !!openAIKey;
 		const anthropicActiveAndKeyProvided = modelKind == ModelKind.Anthropic && !!anthropicKey;
+		const ollamaActiveAndEndpointProvided =
+			modelKind == ModelKind.Ollama && !!ollamaEndpoint && !!ollamaModelName;
 
-		return openAIActiveAndKeyProvided || anthropicActiveAndKeyProvided;
+		return (
+			openAIActiveAndKeyProvided || anthropicActiveAndKeyProvided || ollamaActiveAndEndpointProvided
+		);
 	}
 
 	// This optionally returns a summarizer. There are a few conditions for how this may occur
@@ -199,6 +200,12 @@
 			return new ButlerAIClient(this.cloud, userToken, modelKind);
 		}
 
+		if (modelKind == ModelKind.Ollama) {
+			const ollamaEndpoint = await this.getOllamaEndpoint();
+			const ollamaModelName = await this.getOllamaModelName();
+			return new OllamaClient(ollamaEndpoint, ollamaModelName);
+		}
+
 		if (modelKind == ModelKind.OpenAI) {
 			const openAIModelName = await this.getOpenAIModleName();
 			const openAIKey = await this.getOpenAIKey();
@@ -233,24 +240,37 @@
 		hunks,
 		useEmojiStyle = false,
 		useBriefStyle = false,
-		commitTemplate = defaultCommitTemplate,
+		commitTemplate,
 		userToken
 	}: SummarizeCommitOpts) {
 		const aiClient = await this.buildClient(userToken);
 		if (!aiClient) return;
 
 		const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI();
-		let prompt = commitTemplate.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit));
+		const defaultedCommitTemplate = commitTemplate || aiClient.defaultCommitTemplate;
 
-		const briefPart = useBriefStyle
-			? 'The commit message must be only one sentence and as short as possible.'
-			: '';
-		prompt = prompt.replaceAll('%{brief_style}', briefPart);
+		const prompt = defaultedCommitTemplate.map((promptMessage) => {
+			if (promptMessage.role != MessageRole.User) {
+				return promptMessage;
+			}
 
-		const emojiPart = useEmojiStyle
-			? 'Make use of GitMoji in the title prefix.'
-			: "Don't use any emoji.";
-		prompt = prompt.replaceAll('%{emoji_style}', emojiPart);
+			let content = promptMessage.content.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit));
+
+			const briefPart = useBriefStyle
+				? 'The commit message must be only one sentence and as short as possible.'
+				: '';
+			content = content.replaceAll('%{brief_style}', briefPart);
+
+			const emojiPart = useEmojiStyle
+				? 'Make use of GitMoji in the title prefix.'
+				: "Don't use any emoji.";
+			content = content.replaceAll('%{emoji_style}', emojiPart);
+
+			return {
+				role: MessageRole.User,
+				content
+			};
+		});
 
 		let message = await aiClient.evaluate(prompt);
 
@@ -262,16 +282,23 @@
 		return description ? `${title}\n\n${description}` : title;
 	}
 
-	async summarizeBranch({
-		hunks,
-		branchTemplate = defaultBranchTemplate,
-		userToken = undefined
-	}: SummarizeBranchOpts) {
+	async summarizeBranch({ hunks, branchTemplate, userToken = undefined }: SummarizeBranchOpts) {
 		const aiClient = await this.buildClient(userToken);
 		if (!aiClient) return;
 
 		const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI();
-		const prompt = branchTemplate.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit));
+		const defaultedBranchTemplate = branchTemplate || aiClient.defaultBranchTemplate;
+		const prompt = defaultedBranchTemplate.map((promptMessage) => {
+			if (promptMessage.role != MessageRole.User) {
+				return promptMessage;
+			}
+
+			return {
+				role: MessageRole.User,
+				content: promptMessage.content.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit))
+			};
+		});
+
 		const message = await aiClient.evaluate(prompt);
 		return message.replaceAll(' ', '-').replaceAll('\n', '-');
 	}
diff --git a/app/src/lib/ai/types.ts b/app/src/lib/ai/types.ts
index cc2d71135..55d450beb 100644
--- a/app/src/lib/ai/types.ts
+++ b/app/src/lib/ai/types.ts
@@ -1,6 +1,7 @@
 export enum ModelKind {
 	OpenAI = 'openai',
-	Anthropic = 'anthropic'
+	Anthropic = 'anthropic',
+	Ollama = 'ollama'
 }
 
 export enum OpenAIModelName {
@@ -16,8 +17,9 @@ export enum AnthropicModelName {
 }
 
 export enum MessageRole {
+	System = 'system',
 	User = 'user',
-	Assistant = 'assisstant'
+	Assistant = 'assistant'
 }
 
 export interface PromptMessage {
@@ -26,5 +28,8 @@ export interface PromptMessage {
 }
 
 export interface AIClient {
-	evaluate(prompt: string): Promise<string>;
+	evaluate(prompt: PromptMessage[]): Promise<string>;
+
+	defaultBranchTemplate: PromptMessage[];
+	defaultCommitTemplate: PromptMessage[];
 }
diff --git a/app/src/lib/utils/typeguards.ts b/app/src/lib/utils/typeguards.ts
index 76e9c7c3d..90132d3d2 100644
--- a/app/src/lib/utils/typeguards.ts
+++ b/app/src/lib/utils/typeguards.ts
@@ -5,3 +5,19 @@ export function isDefined<T>(file: T | undefined | null): file is T {
 export function notNull<T>(file: T | undefined | null): file is T {
 	return file !== null;
 }
+
+export type UnknownObject = Record<string, unknown>;
+
+/**
+ * Checks if the provided value is a non-empty object.
+ * @param something - The value to be checked.
+ * @returns A boolean indicating whether the value is a non-empty object.
+ */
+export function isNonEmptyObject(something: unknown): something is UnknownObject {
+	return (
+		typeof something === 'object' &&
+		something !== null &&
+		!Array.isArray(something) &&
+		(Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0)
+	);
+}
diff --git a/app/src/routes/settings/ai/+page.svelte b/app/src/routes/settings/ai/+page.svelte
index 6a53f0e80..5bdd571d2 100644
--- a/app/src/routes/settings/ai/+page.svelte
+++ b/app/src/routes/settings/ai/+page.svelte
@@ -1,6 +1,6 @@
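summarizeCommit now expands the %{} placeholders per User message instead of on one flat string, leaving System and Assistant turns untouched. A worked sketch of the substitution for the brief, no-emoji case (the template text here is abridged):

import { MessageRole, type PromptMessage } from '$lib/ai/types';

const template: PromptMessage[] = [
	{
		role: MessageRole.User,
		content: '%{brief_style}\n%{emoji_style}\nHere is my git diff:\n%{diff}'
	}
];

// With useBriefStyle = true and useEmojiStyle = false, the mapped User turn becomes:
//   The commit message must be only one sentence and as short as possible.
//   Don't use any emoji.
//   Here is my git diff:
//   <diff assembled by buildDiff(hunks, diffLengthLimit)>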
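End to end, the provider is chosen via git config and everything downstream of AIService is unchanged. A sketch of the flow once ModelKind.Ollama is selected — the aiService and hunks values are placeholders for what the surrounding app provides:

import type { AIService } from '$lib/ai/service';
import type { Hunk } from '$lib/vbranches/types';

declare const aiService: AIService;
declare const hunks: Hunk[];

// buildClient() resolves ModelKind.Ollama to an OllamaClient, reading
// gitbutler.aiOllamaEndpoint / gitbutler.aiOllamaModelName from git config and
// falling back to http://127.0.0.1:11434 and llama3.
const commitMessage = await aiService.summarizeCommit({ hunks, useBriefStyle: true });
const branchName = await aiService.summarizeBranch({ hunks });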