Merge pull request #3741 from estib-vega/master

feat: Add support for Ollama
This commit is contained in:
Caleb Owens 2024-05-16 07:00:21 +01:00 committed by GitHub
commit 27535ca409
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 442 additions and 95 deletions

View File

@ -1,24 +1,21 @@
import { import { SHORT_DEFAULT_COMMIT_TEMPLATE, SHORT_DEFAULT_BRANCH_TEMPLATE } from '$lib/ai/prompts';
MessageRole,
type AIClient,
type AnthropicModelName,
type PromptMessage
} from '$lib/ai/types';
import { fetch, Body } from '@tauri-apps/api/http'; import { fetch, Body } from '@tauri-apps/api/http';
import type { AIClient, AnthropicModelName, PromptMessage } from '$lib/ai/types';
type AnthropicAPIResponse = { content: { text: string }[] }; type AnthropicAPIResponse = { content: { text: string }[] };
export class AnthropicAIClient implements AIClient { export class AnthropicAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
constructor( constructor(
private apiKey: string, private apiKey: string,
private modelName: AnthropicModelName private modelName: AnthropicModelName
) {} ) {}
async evaluate(prompt: string) { async evaluate(prompt: PromptMessage[]) {
const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }];
const body = Body.json({ const body = Body.json({
messages, messages: prompt,
max_tokens: 1024, max_tokens: 1024,
model: this.modelName model: this.modelName
}); });

View File

@ -1,19 +1,21 @@
import { MessageRole, type ModelKind, type AIClient, type PromptMessage } from '$lib/ai/types'; import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import type { AIClient, ModelKind, PromptMessage } from '$lib/ai/types';
import type { HttpClient } from '$lib/backend/httpClient'; import type { HttpClient } from '$lib/backend/httpClient';
export class ButlerAIClient implements AIClient { export class ButlerAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
constructor( constructor(
private cloud: HttpClient, private cloud: HttpClient,
private userToken: string, private userToken: string,
private modelKind: ModelKind private modelKind: ModelKind
) {} ) {}
async evaluate(prompt: string) { async evaluate(prompt: PromptMessage[]) {
const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }];
const response = await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', { const response = await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', {
body: { body: {
messages, messages: prompt,
max_tokens: 400, max_tokens: 400,
model_kind: this.modelKind model_kind: this.modelKind
}, },

View File

@ -0,0 +1,165 @@
import { LONG_DEFAULT_BRANCH_TEMPLATE, LONG_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { MessageRole, type PromptMessage, type AIClient } from '$lib/ai/types';
import { isNonEmptyObject } from '$lib/utils/typeguards';
export const DEFAULT_OLLAMA_ENDPOINT = 'http://127.0.0.1:11434';
export const DEFAULT_OLLAMA_MODEL_NAME = 'llama3';
enum OllamaAPEndpoint {
Generate = 'api/generate',
Chat = 'api/chat',
Embed = 'api/embeddings'
}
interface OllamaRequestOptions {
/**
* The temperature of the model.
* Increasing the temperature will make the model answer more creatively. (Default: 0.8)
*/
temperature: number;
}
interface OllamaChatRequest {
model: string;
messages: PromptMessage[];
stream: boolean;
format?: 'json';
options?: OllamaRequestOptions;
}
interface BaseOllamaMResponse {
created_at: string;
done: boolean;
model: string;
}
interface OllamaChatResponse extends BaseOllamaMResponse {
message: PromptMessage;
done: true;
}
interface OllamaChatMessageFormat {
result: string;
}
const OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA = {
type: 'object',
properties: {
result: { type: 'string' }
},
required: ['result'],
additionalProperties: false
};
function isOllamaChatMessageFormat(message: unknown): message is OllamaChatMessageFormat {
if (!isNonEmptyObject(message)) {
return false;
}
return typeof message.result === 'string';
}
function isOllamaChatResponse(response: unknown): response is OllamaChatResponse {
if (!isNonEmptyObject(response)) {
return false;
}
return (
isNonEmptyObject(response.message) &&
typeof response.message.role == 'string' &&
typeof response.message.content == 'string'
);
}
export class OllamaClient implements AIClient {
defaultCommitTemplate = LONG_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = LONG_DEFAULT_BRANCH_TEMPLATE;
constructor(
private endpoint: string,
private modelName: string
) {}
async evaluate(prompt: PromptMessage[]) {
const messages = this.formatPrompt(prompt);
const response = await this.chat(messages);
const rawResponse = JSON.parse(response.message.content);
if (!isOllamaChatMessageFormat(rawResponse)) {
throw new Error('Invalid response: ' + response.message.content);
}
return rawResponse.result;
}
/**
* Appends a system message which instructs the model to respond using a particular JSON schema
* Modifies the prompt's Assistant messages to make use of the correct schema
*/
private formatPrompt(prompt: PromptMessage[]) {
const withFormattedResponses = prompt.map((promptMessage) => {
if (promptMessage.role == MessageRole.Assistant) {
return {
role: MessageRole.Assistant,
content: JSON.stringify({ result: promptMessage.content })
};
} else {
return promptMessage;
}
});
return [
{
role: MessageRole.System,
content: `You are an expert in software development. Answer the given user prompts following the specified instructions.
Return your response in JSON and only use the following JSON schema:
${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}`
},
...withFormattedResponses
];
}
/**
* Fetches the chat using the specified request.
* @param request - The OllamaChatRequest object containing the request details.
* @returns A Promise that resolves to the Response object.
*/
private async fetchChat(request: OllamaChatRequest): Promise<Response> {
const url = new URL(OllamaAPEndpoint.Chat, this.endpoint);
const result = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify(request)
});
return result;
}
/**
* Sends a chat message to the LLM model and returns the response.
*
* @param messages - An array of LLMChatMessage objects representing the chat messages.
* @param options - Optional LLMRequestOptions object for specifying additional options.
* @throws Error if the response is invalid.
* @returns A Promise that resolves to an LLMResponse object representing the response from the LLM model.
*/
private async chat(
messages: PromptMessage[],
options?: OllamaRequestOptions
): Promise<OllamaChatResponse> {
const result = await this.fetchChat({
model: this.modelName,
stream: false,
messages,
options,
format: 'json'
});
const json = await result.json();
if (!isOllamaChatResponse(json)) {
throw new Error('Invalid response\n' + JSON.stringify(json));
}
return json;
}
}

View File

@ -1,24 +1,19 @@
import { import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
MessageRole, import type { OpenAIModelName, PromptMessage, AIClient } from '$lib/ai/types';
type OpenAIModelName,
type PromptMessage,
type AIClient
} from '$lib/ai/types';
import type OpenAI from 'openai'; import type OpenAI from 'openai';
export class OpenAIClient implements AIClient { export class OpenAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
constructor( constructor(
private modelName: OpenAIModelName, private modelName: OpenAIModelName,
private openAI: OpenAI private openAI: OpenAI
) {} ) {}
async evaluate(prompt: string) { async evaluate(prompt: PromptMessage[]) {
const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }];
const response = await this.openAI.chat.completions.create({ const response = await this.openAI.chat.completions.create({
// @ts-expect-error There is a type mismatch where it seems to want a "name" paramater messages: prompt,
// that isn't required https://github.com/openai/openai-openapi/issues/118#issuecomment-1847667988
messages,
model: this.modelName, model: this.modelName,
max_tokens: 400 max_tokens: 400
}); });

107
app/src/lib/ai/prompts.ts Normal file
View File

@ -0,0 +1,107 @@
import { type PromptMessage, MessageRole } from '$lib/ai/types';
export const SHORT_DEFAULT_COMMIT_TEMPLATE: PromptMessage[] = [
{
role: MessageRole.User,
content: `Please could you write a commit message for my changes.
Only respond with the commit message. Don't give any notes.
Explain what were the changes and why the changes were done.
Focus the most important changes.
Use the present tense.
Use a semantic commit prefix.
Hard wrap lines at 72 characters.
Ensure the title is only 50 characters.
Do not start any lines with the hash symbol.
%{brief_style}
%{emoji_style}
Here is my git diff:
%{diff}`
}
];
export const LONG_DEFAULT_COMMIT_TEMPLATE: PromptMessage[] = [
{
role: MessageRole.User,
content: `Please could you write a commit message for my changes.
Explain what were the changes and why the changes were done.
Focus the most important changes.
Use the present tense.
Use a semantic commit prefix.
Hard wrap lines at 72 characters.
Ensure the title is only 50 characters.
Do not start any lines with the hash symbol.
Only respond with the commit message.
Here is my git diff:
diff --git a/src/utils/typing.ts b/src/utils/typing.ts
index 1cbfaa2..7aeebcf 100644
--- a/src/utils/typing.ts
+++ b/src/utils/typing.ts
@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject
(Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0)
);
}
+
+export function isArrayOf<T>(
+ something: unknown,
+ check: (value: unknown) => value is T
+): something is T[] {
+ return Array.isArray(something) && something.every(check);
+}`
},
{
role: MessageRole.Assistant,
content: `Typing utilities: Check for array of type
Added an utility function to check whether a given value is an array of a specific type.`
},
...SHORT_DEFAULT_COMMIT_TEMPLATE
];
export const SHORT_DEFAULT_BRANCH_TEMPLATE: PromptMessage[] = [
{
role: MessageRole.User,
content: `Please could you write a branch name for my changes.
A branch name represent a brief description of the changes in the diff (branch).
Branch names should contain no whitespace and instead use dashes to separate words.
Branch names should contain a maximum of 5 words.
Only respond with the branch name.
Here is my git diff:
%{diff}`
}
];
export const LONG_DEFAULT_BRANCH_TEMPLATE: PromptMessage[] = [
{
role: MessageRole.User,
content: `Please could you write a branch name for my changes.
A branch name represent a brief description of the changes in the diff (branch).
Branch names should contain no whitespace and instead use dashes to separate words.
Branch names should contain a maximum of 5 words.
Only respond with the branch name.
Here is my git diff:
diff --git a/src/utils/typing.ts b/src/utils/typing.ts
index 1cbfaa2..7aeebcf 100644
--- a/src/utils/typing.ts
+++ b/src/utils/typing.ts
@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject
(Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0)
);
}
+
+export function isArrayOf<T>(
+ something: unknown,
+ check: (value: unknown) => value is T
+): something is T[] {
+ return Array.isArray(something) && something.every(check);
+}`
},
{
role: MessageRole.Assistant,
content: `utils-typing-is-array-of-type`
},
...SHORT_DEFAULT_BRANCH_TEMPLATE
];

View File

@ -1,8 +1,15 @@
import { AnthropicAIClient } from '$lib/ai/anthropicClient'; import { AnthropicAIClient } from '$lib/ai/anthropicClient';
import { ButlerAIClient } from '$lib/ai/butlerClient'; import { ButlerAIClient } from '$lib/ai/butlerClient';
import { OpenAIClient } from '$lib/ai/openAIClient'; import { OpenAIClient } from '$lib/ai/openAIClient';
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { AIService, GitAIConfigKey, KeyOption, buildDiff } from '$lib/ai/service'; import { AIService, GitAIConfigKey, KeyOption, buildDiff } from '$lib/ai/service';
import { AnthropicModelName, ModelKind, OpenAIModelName, type AIClient } from '$lib/ai/types'; import {
AnthropicModelName,
ModelKind,
OpenAIModelName,
type AIClient,
type PromptMessage
} from '$lib/ai/types';
import { HttpClient } from '$lib/backend/httpClient'; import { HttpClient } from '$lib/backend/httpClient';
import * as toasts from '$lib/utils/toasts'; import * as toasts from '$lib/utils/toasts';
import { Hunk } from '$lib/vbranches/types'; import { Hunk } from '$lib/vbranches/types';
@ -40,9 +47,11 @@ const fetchMock = vi.fn();
const cloud = new HttpClient(fetchMock); const cloud = new HttpClient(fetchMock);
class DummyAIClient implements AIClient { class DummyAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
constructor(private response = 'lorem ipsum') {} constructor(private response = 'lorem ipsum') {}
async evaluate(_prompt: string) { async evaluate(_prompt: PromptMessage[]) {
return this.response; return this.response;
} }
} }

View File

@ -1,7 +1,19 @@
import { AnthropicAIClient } from '$lib/ai/anthropicClient'; import { AnthropicAIClient } from '$lib/ai/anthropicClient';
import { ButlerAIClient } from '$lib/ai/butlerClient'; import { ButlerAIClient } from '$lib/ai/butlerClient';
import {
DEFAULT_OLLAMA_ENDPOINT,
DEFAULT_OLLAMA_MODEL_NAME,
OllamaClient
} from '$lib/ai/ollamaClient';
import { OpenAIClient } from '$lib/ai/openAIClient'; import { OpenAIClient } from '$lib/ai/openAIClient';
import { OpenAIModelName, type AIClient, AnthropicModelName } from '$lib/ai/types'; import {
OpenAIModelName,
type AIClient,
AnthropicModelName,
ModelKind,
type PromptMessage,
MessageRole
} from '$lib/ai/types';
import { splitMessage } from '$lib/utils/commitMessage'; import { splitMessage } from '$lib/utils/commitMessage';
import * as toasts from '$lib/utils/toasts'; import * as toasts from '$lib/utils/toasts';
import OpenAI from 'openai'; import OpenAI from 'openai';
@ -11,39 +23,6 @@ import type { Hunk } from '$lib/vbranches/types';
const maxDiffLengthLimitForAPI = 5000; const maxDiffLengthLimitForAPI = 5000;
const defaultCommitTemplate = `
Please could you write a commit message for my changes.
Explain what were the changes and why the changes were done.
Focus the most important changes.
Use the present tense.
Use a semantic commit prefix.
Hard wrap lines at 72 characters.
Ensure the title is only 50 characters.
Do not start any lines with the hash symbol.
Only respond with the commit message.
%{brief_style}
%{emoji_style}
Here is my git diff:
%{diff}
`;
const defaultBranchTemplate = `
Please could you write a branch name for my changes.
A branch name represent a brief description of the changes in the diff (branch).
Branch names should contain no whitespace and instead use dashes to separate words.
Branch names should contain a maximum of 5 words.
Only respond with the branch name.
Here is my git diff:
%{diff}
`;
export enum ModelKind {
OpenAI = 'openai',
Anthropic = 'anthropic'
}
export enum KeyOption { export enum KeyOption {
BringYourOwn = 'bringYourOwn', BringYourOwn = 'bringYourOwn',
ButlerAPI = 'butlerAPI' ButlerAPI = 'butlerAPI'
@ -57,20 +36,22 @@ export enum GitAIConfigKey {
AnthropicKeyOption = 'gitbutler.aiAnthropicKeyOption', AnthropicKeyOption = 'gitbutler.aiAnthropicKeyOption',
AnthropicModelName = 'gitbutler.aiAnthropicModelName', AnthropicModelName = 'gitbutler.aiAnthropicModelName',
AnthropicKey = 'gitbutler.aiAnthropicKey', AnthropicKey = 'gitbutler.aiAnthropicKey',
DiffLengthLimit = 'gitbutler.diffLengthLimit' DiffLengthLimit = 'gitbutler.diffLengthLimit',
OllamaEndpoint = 'gitbutler.aiOllamaEndpoint',
OllamaModelName = 'gitbutler.aiOllamaModelName'
} }
type SummarizeCommitOpts = { type SummarizeCommitOpts = {
hunks: Hunk[]; hunks: Hunk[];
useEmojiStyle?: boolean; useEmojiStyle?: boolean;
useBriefStyle?: boolean; useBriefStyle?: boolean;
commitTemplate?: string; commitTemplate?: PromptMessage[];
userToken?: string; userToken?: string;
}; };
type SummarizeBranchOpts = { type SummarizeBranchOpts = {
hunks: Hunk[]; hunks: Hunk[];
branchTemplate?: string; branchTemplate?: PromptMessage[];
userToken?: string; userToken?: string;
}; };
@ -84,7 +65,7 @@ export function buildDiff(hunks: Hunk[], limit: number) {
function shuffle<T>(items: T[]): T[] { function shuffle<T>(items: T[]): T[] {
return items return items
.map((item) => ({ item, value: Math.random() })) .map((item) => ({ item, value: Math.random() }))
.sort() .sort(({ value: a }, { value: b }) => a - b)
.map((item) => item.item); .map((item) => item.item);
} }
@ -159,6 +140,20 @@ export class AIService {
} }
} }
async getOllamaEndpoint() {
return await this.gitConfig.getWithDefault<string>(
GitAIConfigKey.OllamaEndpoint,
DEFAULT_OLLAMA_ENDPOINT
);
}
async getOllamaModelName() {
return await this.gitConfig.getWithDefault<string>(
GitAIConfigKey.OllamaModelName,
DEFAULT_OLLAMA_MODEL_NAME
);
}
async usingGitButlerAPI() { async usingGitButlerAPI() {
const modelKind = await this.getModelKind(); const modelKind = await this.getModelKind();
const openAIKeyOption = await this.getOpenAIKeyOption(); const openAIKeyOption = await this.getOpenAIKeyOption();
@ -176,13 +171,19 @@ export class AIService {
const modelKind = await this.getModelKind(); const modelKind = await this.getModelKind();
const openAIKey = await this.getOpenAIKey(); const openAIKey = await this.getOpenAIKey();
const anthropicKey = await this.getAnthropicKey(); const anthropicKey = await this.getAnthropicKey();
const ollamaEndpoint = await this.getOllamaEndpoint();
const ollamaModelName = await this.getOllamaModelName();
if (await this.usingGitButlerAPI()) return !!userToken; if (await this.usingGitButlerAPI()) return !!userToken;
const openAIActiveAndKeyProvided = modelKind == ModelKind.OpenAI && !!openAIKey; const openAIActiveAndKeyProvided = modelKind == ModelKind.OpenAI && !!openAIKey;
const anthropicActiveAndKeyProvided = modelKind == ModelKind.Anthropic && !!anthropicKey; const anthropicActiveAndKeyProvided = modelKind == ModelKind.Anthropic && !!anthropicKey;
const ollamaActiveAndEndpointProvided =
modelKind == ModelKind.Ollama && !!ollamaEndpoint && !!ollamaModelName;
return openAIActiveAndKeyProvided || anthropicActiveAndKeyProvided; return (
openAIActiveAndKeyProvided || anthropicActiveAndKeyProvided || ollamaActiveAndEndpointProvided
);
} }
// This optionally returns a summarizer. There are a few conditions for how this may occur // This optionally returns a summarizer. There are a few conditions for how this may occur
@ -199,6 +200,12 @@ export class AIService {
return new ButlerAIClient(this.cloud, userToken, modelKind); return new ButlerAIClient(this.cloud, userToken, modelKind);
} }
if (modelKind == ModelKind.Ollama) {
const ollamaEndpoint = await this.getOllamaEndpoint();
const ollamaModelName = await this.getOllamaModelName();
return new OllamaClient(ollamaEndpoint, ollamaModelName);
}
if (modelKind == ModelKind.OpenAI) { if (modelKind == ModelKind.OpenAI) {
const openAIModelName = await this.getOpenAIModleName(); const openAIModelName = await this.getOpenAIModleName();
const openAIKey = await this.getOpenAIKey(); const openAIKey = await this.getOpenAIKey();
@ -233,24 +240,37 @@ export class AIService {
hunks, hunks,
useEmojiStyle = false, useEmojiStyle = false,
useBriefStyle = false, useBriefStyle = false,
commitTemplate = defaultCommitTemplate, commitTemplate,
userToken userToken
}: SummarizeCommitOpts) { }: SummarizeCommitOpts) {
const aiClient = await this.buildClient(userToken); const aiClient = await this.buildClient(userToken);
if (!aiClient) return; if (!aiClient) return;
const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI(); const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI();
let prompt = commitTemplate.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit)); const defaultedCommitTemplate = commitTemplate || aiClient.defaultCommitTemplate;
const prompt = defaultedCommitTemplate.map((promptMessage) => {
if (promptMessage.role != MessageRole.User) {
return promptMessage;
}
let content = promptMessage.content.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit));
const briefPart = useBriefStyle const briefPart = useBriefStyle
? 'The commit message must be only one sentence and as short as possible.' ? 'The commit message must be only one sentence and as short as possible.'
: ''; : '';
prompt = prompt.replaceAll('%{brief_style}', briefPart); content = content.replaceAll('%{brief_style}', briefPart);
const emojiPart = useEmojiStyle const emojiPart = useEmojiStyle
? 'Make use of GitMoji in the title prefix.' ? 'Make use of GitMoji in the title prefix.'
: "Don't use any emoji."; : "Don't use any emoji.";
prompt = prompt.replaceAll('%{emoji_style}', emojiPart); content = content.replaceAll('%{emoji_style}', emojiPart);
return {
role: MessageRole.User,
content
};
});
let message = await aiClient.evaluate(prompt); let message = await aiClient.evaluate(prompt);
@ -262,16 +282,23 @@ export class AIService {
return description ? `${title}\n\n${description}` : title; return description ? `${title}\n\n${description}` : title;
} }
async summarizeBranch({ async summarizeBranch({ hunks, branchTemplate, userToken = undefined }: SummarizeBranchOpts) {
hunks,
branchTemplate = defaultBranchTemplate,
userToken = undefined
}: SummarizeBranchOpts) {
const aiClient = await this.buildClient(userToken); const aiClient = await this.buildClient(userToken);
if (!aiClient) return; if (!aiClient) return;
const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI(); const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI();
const prompt = branchTemplate.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit)); const defaultedBranchTemplate = branchTemplate || aiClient.defaultBranchTemplate;
const prompt = defaultedBranchTemplate.map((promptMessage) => {
if (promptMessage.role != MessageRole.User) {
return promptMessage;
}
return {
role: MessageRole.User,
content: promptMessage.content.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit))
};
});
const message = await aiClient.evaluate(prompt); const message = await aiClient.evaluate(prompt);
return message.replaceAll(' ', '-').replaceAll('\n', '-'); return message.replaceAll(' ', '-').replaceAll('\n', '-');
} }

View File

@ -1,6 +1,7 @@
export enum ModelKind { export enum ModelKind {
OpenAI = 'openai', OpenAI = 'openai',
Anthropic = 'anthropic' Anthropic = 'anthropic',
Ollama = 'ollama'
} }
export enum OpenAIModelName { export enum OpenAIModelName {
@ -16,8 +17,9 @@ export enum AnthropicModelName {
} }
export enum MessageRole { export enum MessageRole {
System = 'system',
User = 'user', User = 'user',
Assistant = 'assisstant' Assistant = 'assistant'
} }
export interface PromptMessage { export interface PromptMessage {
@ -26,5 +28,8 @@ export interface PromptMessage {
} }
export interface AIClient { export interface AIClient {
evaluate(prompt: string): Promise<string>; evaluate(prompt: PromptMessage[]): Promise<string>;
defaultBranchTemplate: PromptMessage[];
defaultCommitTemplate: PromptMessage[];
} }

View File

@ -5,3 +5,19 @@ export function isDefined<T>(file: T | undefined | null): file is T {
export function notNull<T>(file: T | undefined | null): file is T { export function notNull<T>(file: T | undefined | null): file is T {
return file !== null; return file !== null;
} }
export type UnknownObject = Record<string, unknown>;
/**
* Checks if the provided value is a non-empty object.
* @param something - The value to be checked.
* @returns A boolean indicating whether the value is a non-empty object.
*/
export function isNonEmptyObject(something: unknown): something is UnknownObject {
return (
typeof something === 'object' &&
something !== null &&
!Array.isArray(something) &&
(Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0)
);
}

View File

@ -1,6 +1,6 @@
<script lang="ts"> <script lang="ts">
import { AIService, GitAIConfigKey, KeyOption, ModelKind } from '$lib/ai/service'; import { AIService, GitAIConfigKey, KeyOption } from '$lib/ai/service';
import { OpenAIModelName, AnthropicModelName } from '$lib/ai/types'; import { OpenAIModelName, AnthropicModelName, ModelKind } from '$lib/ai/types';
import { GitConfigService } from '$lib/backend/gitConfigService'; import { GitConfigService } from '$lib/backend/gitConfigService';
import InfoMessage from '$lib/components/InfoMessage.svelte'; import InfoMessage from '$lib/components/InfoMessage.svelte';
import RadioButton from '$lib/components/RadioButton.svelte'; import RadioButton from '$lib/components/RadioButton.svelte';
@ -30,6 +30,8 @@
let anthropicKey: string | undefined; let anthropicKey: string | undefined;
let anthropicModelName: AnthropicModelName | undefined; let anthropicModelName: AnthropicModelName | undefined;
let diffLengthLimit: number | undefined; let diffLengthLimit: number | undefined;
let ollamaEndpoint: string | undefined;
let ollamaModel: string | undefined;
function setConfiguration(key: GitAIConfigKey, value: string | undefined) { function setConfiguration(key: GitAIConfigKey, value: string | undefined) {
if (!initialized) return; if (!initialized) return;
@ -48,6 +50,9 @@
$: setConfiguration(GitAIConfigKey.AnthropicKey, anthropicKey); $: setConfiguration(GitAIConfigKey.AnthropicKey, anthropicKey);
$: setConfiguration(GitAIConfigKey.DiffLengthLimit, diffLengthLimit?.toString()); $: setConfiguration(GitAIConfigKey.DiffLengthLimit, diffLengthLimit?.toString());
$: setConfiguration(GitAIConfigKey.OllamaEndpoint, ollamaEndpoint);
$: setConfiguration(GitAIConfigKey.OllamaModelName, ollamaModel);
onMount(async () => { onMount(async () => {
modelKind = await aiService.getModelKind(); modelKind = await aiService.getModelKind();
@ -61,6 +66,9 @@
diffLengthLimit = await aiService.getDiffLengthLimit(); diffLengthLimit = await aiService.getDiffLengthLimit();
ollamaEndpoint = await aiService.getOllamaEndpoint();
ollamaModel = await aiService.getOllamaModelName();
// Ensure reactive declarations have finished running before we set initialized to true // Ensure reactive declarations have finished running before we set initialized to true
await tick(); await tick();
@ -261,15 +269,31 @@
</SectionCard> </SectionCard>
{/if} {/if}
<SectionCard roundedTop={false} orientation="row" disabled={true}> <SectionCard
<svelte:fragment slot="title">Custom Endpoint</svelte:fragment> roundedTop={false}
<svelte:fragment slot="actions"> roundedBottom={modelKind != ModelKind.Ollama}
<RadioButton disabled={true} name="modelKind" /> orientation="row"
</svelte:fragment> labelFor="ollama"
<svelte:fragment slot="caption" bottomBorder={modelKind != ModelKind.Ollama}
>Support for custom AI endpoints is coming soon!</svelte:fragment
> >
<svelte:fragment slot="title">Ollama 🦙</svelte:fragment>
<svelte:fragment slot="actions">
<RadioButton name="modelKind" id="ollama" value={ModelKind.Ollama} />
</svelte:fragment>
</SectionCard> </SectionCard>
{#if modelKind == ModelKind.Ollama}
<SectionCard hasTopRadius={false} roundedTop={false} orientation="row" topDivider>
<div class="inputs-group">
<TextBox
label="Endpoint"
bind:value={ollamaEndpoint}
placeholder="http://127.0.0.1:11434"
/>
<TextBox label="Model" bind:value={ollamaModel} placeholder="llama3" />
</div>
</SectionCard>
{/if}
</form> </form>
<Spacer /> <Spacer />