Improve ai error handling (#4180)

* Introduce a result type

* Update AI error handling to use the result type

* Handle ollama json parse error

* Migrate using Error as the type that represents errors

* Remove now useless condition

* asdfasdf

* Use andThen

* Correct unit tests
This commit is contained in:
Caleb Owens 2024-06-27 21:50:44 +02:00 committed by GitHub
parent 518cc8b77e
commit dd0b4eccf1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 329 additions and 135 deletions

View File

@ -1,8 +1,12 @@
import { SHORT_DEFAULT_COMMIT_TEMPLATE, SHORT_DEFAULT_BRANCH_TEMPLATE } from '$lib/ai/prompts'; import { SHORT_DEFAULT_COMMIT_TEMPLATE, SHORT_DEFAULT_BRANCH_TEMPLATE } from '$lib/ai/prompts';
import { type AIClient, type AnthropicModelName, type Prompt } from '$lib/ai/types';
import { buildFailureFromAny, ok, type Result } from '$lib/result';
import { fetch, Body } from '@tauri-apps/api/http'; import { fetch, Body } from '@tauri-apps/api/http';
import type { AIClient, AnthropicModelName, Prompt } from '$lib/ai/types';
type AnthropicAPIResponse = { content: { text: string }[] }; type AnthropicAPIResponse = {
content: { text: string }[];
error: { type: string; message: string };
};
export class AnthropicAIClient implements AIClient { export class AnthropicAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE; defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
@ -13,7 +17,7 @@ export class AnthropicAIClient implements AIClient {
private modelName: AnthropicModelName private modelName: AnthropicModelName
) {} ) {}
async evaluate(prompt: Prompt) { async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
const body = Body.json({ const body = Body.json({
messages: prompt, messages: prompt,
max_tokens: 1024, max_tokens: 1024,
@ -30,6 +34,12 @@ export class AnthropicAIClient implements AIClient {
body body
}); });
return response.data.content[0].text; if (response.ok && response.data?.content?.[0]?.text) {
return ok(response.data.content[0].text);
} else {
return buildFailureFromAny(
`Anthropic returned error code ${response.status} ${response.data?.error?.message}`
);
}
} }
} }

View File

@ -1,4 +1,5 @@
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts'; import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { map, type Result } from '$lib/result';
import type { AIClient, ModelKind, Prompt } from '$lib/ai/types'; import type { AIClient, ModelKind, Prompt } from '$lib/ai/types';
import type { HttpClient } from '$lib/backend/httpClient'; import type { HttpClient } from '$lib/backend/httpClient';
@ -12,16 +13,19 @@ export class ButlerAIClient implements AIClient {
private modelKind: ModelKind private modelKind: ModelKind
) {} ) {}
async evaluate(prompt: Prompt) { async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
const response = await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', { const response = await this.cloud.postSafe<{ message: string }>(
body: { 'evaluate_prompt/predict.json',
messages: prompt, {
max_tokens: 400, body: {
model_kind: this.modelKind messages: prompt,
}, max_tokens: 400,
token: this.userToken model_kind: this.modelKind
}); },
token: this.userToken
}
);
return response.message; return map(response, ({ message }) => message);
} }
} }

View File

@ -1,5 +1,6 @@
import { LONG_DEFAULT_BRANCH_TEMPLATE, LONG_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts'; import { LONG_DEFAULT_BRANCH_TEMPLATE, LONG_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { MessageRole, type PromptMessage, type AIClient, type Prompt } from '$lib/ai/types'; import { MessageRole, type PromptMessage, type AIClient, type Prompt } from '$lib/ai/types';
import { andThen, buildFailureFromAny, ok, wrap, wrapAsync, type Result } from '$lib/result';
import { isNonEmptyObject } from '$lib/utils/typeguards'; import { isNonEmptyObject } from '$lib/utils/typeguards';
import { fetch, Body, Response } from '@tauri-apps/api/http'; import { fetch, Body, Response } from '@tauri-apps/api/http';
@ -81,15 +82,22 @@ export class OllamaClient implements AIClient {
private modelName: string private modelName: string
) {} ) {}
async evaluate(prompt: Prompt) { async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
const messages = this.formatPrompt(prompt); const messages = this.formatPrompt(prompt);
const response = await this.chat(messages);
const rawResponse = JSON.parse(response.message.content);
if (!isOllamaChatMessageFormat(rawResponse)) {
throw new Error('Invalid response: ' + response.message.content);
}
return rawResponse.result; const responseResult = await this.chat(messages);
return andThen(responseResult, (response) => {
const rawResponseResult = wrap<unknown, Error>(() => JSON.parse(response.message.content));
return andThen(rawResponseResult, (rawResponse) => {
if (!isOllamaChatMessageFormat(rawResponse)) {
return buildFailureFromAny('Invalid response: ' + response.message.content);
}
return ok(rawResponse.result);
});
});
} }
/** /**
@ -124,17 +132,19 @@ ${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}`
* @param request - The OllamaChatRequest object containing the request details. * @param request - The OllamaChatRequest object containing the request details.
* @returns A Promise that resolves to the Response object. * @returns A Promise that resolves to the Response object.
*/ */
private async fetchChat(request: OllamaChatRequest): Promise<Response<any>> { private async fetchChat(request: OllamaChatRequest): Promise<Result<Response<any>, Error>> {
const url = new URL(OllamaAPEndpoint.Chat, this.endpoint); const url = new URL(OllamaAPEndpoint.Chat, this.endpoint);
const body = Body.json(request); const body = Body.json(request);
const result = await fetch(url.toString(), { return await wrapAsync(
method: 'POST', async () =>
headers: { await fetch(url.toString(), {
'Content-Type': 'application/json' method: 'POST',
}, headers: {
body 'Content-Type': 'application/json'
}); },
return result; body
})
);
} }
/** /**
@ -142,13 +152,12 @@ ${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}`
* *
* @param messages - An array of LLMChatMessage objects representing the chat messages. * @param messages - An array of LLMChatMessage objects representing the chat messages.
* @param options - Optional LLMRequestOptions object for specifying additional options. * @param options - Optional LLMRequestOptions object for specifying additional options.
* @throws Error if the response is invalid.
* @returns A Promise that resolves to an LLMResponse object representing the response from the LLM model. * @returns A Promise that resolves to an LLMResponse object representing the response from the LLM model.
*/ */
private async chat( private async chat(
messages: Prompt, messages: Prompt,
options?: OllamaRequestOptions options?: OllamaRequestOptions
): Promise<OllamaChatResponse> { ): Promise<Result<OllamaChatResponse, Error>> {
const result = await this.fetchChat({ const result = await this.fetchChat({
model: this.modelName, model: this.modelName,
stream: false, stream: false,
@ -157,10 +166,12 @@ ${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}`
format: 'json' format: 'json'
}); });
if (!isOllamaChatResponse(result.data)) { return andThen(result, (result) => {
throw new Error('Invalid response\n' + JSON.stringify(result.data)); if (!isOllamaChatResponse(result.data)) {
} return buildFailureFromAny('Invalid response\n' + JSON.stringify(result.data));
}
return result.data; return ok(result.data);
});
} }
} }

View File

@ -1,6 +1,8 @@
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts'; import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { andThen, buildFailureFromAny, ok, wrapAsync, type Result } from '$lib/result';
import type { OpenAIModelName, Prompt, AIClient } from '$lib/ai/types'; import type { OpenAIModelName, Prompt, AIClient } from '$lib/ai/types';
import type OpenAI from 'openai'; import type OpenAI from 'openai';
import type { ChatCompletion } from 'openai/resources/index.mjs';
export class OpenAIClient implements AIClient { export class OpenAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE; defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
@ -11,13 +13,21 @@ export class OpenAIClient implements AIClient {
private openAI: OpenAI private openAI: OpenAI
) {} ) {}
async evaluate(prompt: Prompt) { async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
const response = await this.openAI.chat.completions.create({ const responseResult = await wrapAsync<ChatCompletion, Error>(async () => {
messages: prompt, return await this.openAI.chat.completions.create({
model: this.modelName, messages: prompt,
max_tokens: 400 model: this.modelName,
max_tokens: 400
});
}); });
return response.choices[0].message.content || ''; return andThen(responseResult, (response) => {
if (response.choices[0]?.message.content) {
return ok(response.choices[0]?.message.content);
} else {
return buildFailureFromAny('Open AI generated an empty message');
}
});
} }
} }

View File

@ -11,7 +11,7 @@ import {
type Prompt type Prompt
} from '$lib/ai/types'; } from '$lib/ai/types';
import { HttpClient } from '$lib/backend/httpClient'; import { HttpClient } from '$lib/backend/httpClient';
import * as toasts from '$lib/utils/toasts'; import { buildFailureFromAny, ok, unwrap, type Result } from '$lib/result';
import { Hunk } from '$lib/vbranches/types'; import { Hunk } from '$lib/vbranches/types';
import { plainToInstance } from 'class-transformer'; import { plainToInstance } from 'class-transformer';
import { expect, test, describe, vi } from 'vitest'; import { expect, test, describe, vi } from 'vitest';
@ -56,8 +56,8 @@ class DummyAIClient implements AIClient {
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE; defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
constructor(private response = 'lorem ipsum') {} constructor(private response = 'lorem ipsum') {}
async evaluate(_prompt: Prompt) { async evaluate(_prompt: Prompt): Promise<Result<string, Error>> {
return this.response; return ok(this.response);
} }
} }
@ -116,16 +116,14 @@ describe.concurrent('AIService', () => {
test('With default configuration, When a user token is provided. It returns ButlerAIClient', async () => { test('With default configuration, When a user token is provided. It returns ButlerAIClient', async () => {
const aiService = buildDefaultAIService(); const aiService = buildDefaultAIService();
expect(await aiService.buildClient('token')).toBeInstanceOf(ButlerAIClient); expect(unwrap(await aiService.buildClient('token'))).toBeInstanceOf(ButlerAIClient);
}); });
test('With default configuration, When a user is undefined. It returns undefined', async () => { test('With default configuration, When a user is undefined. It returns undefined', async () => {
const toastErrorSpy = vi.spyOn(toasts, 'error');
const aiService = buildDefaultAIService(); const aiService = buildDefaultAIService();
expect(await aiService.buildClient()).toBe(undefined); expect(await aiService.buildClient()).toStrictEqual(
expect(toastErrorSpy).toHaveBeenLastCalledWith( buildFailureFromAny("When using GitButler's API to summarize code, you must be logged in")
"When using GitButler's API to summarize code, you must be logged in"
); );
}); });
@ -137,11 +135,10 @@ describe.concurrent('AIService', () => {
}); });
const aiService = new AIService(gitConfig, cloud); const aiService = new AIService(gitConfig, cloud);
expect(await aiService.buildClient()).toBeInstanceOf(OpenAIClient); expect(unwrap(await aiService.buildClient())).toBeInstanceOf(OpenAIClient);
}); });
test('When token is bring your own, When a openAI token is blank. It returns undefined', async () => { test('When token is bring your own, When a openAI token is blank. It returns undefined', async () => {
const toastErrorSpy = vi.spyOn(toasts, 'error');
const gitConfig = new DummyGitConfigService({ const gitConfig = new DummyGitConfigService({
...defaultGitConfig, ...defaultGitConfig,
[GitAIConfigKey.OpenAIKeyOption]: KeyOption.BringYourOwn, [GitAIConfigKey.OpenAIKeyOption]: KeyOption.BringYourOwn,
@ -149,9 +146,10 @@ describe.concurrent('AIService', () => {
}); });
const aiService = new AIService(gitConfig, cloud); const aiService = new AIService(gitConfig, cloud);
expect(await aiService.buildClient()).toBe(undefined); expect(await aiService.buildClient()).toStrictEqual(
expect(toastErrorSpy).toHaveBeenLastCalledWith( buildFailureFromAny(
'When using OpenAI in a bring your own key configuration, you must provide a valid token' 'When using OpenAI in a bring your own key configuration, you must provide a valid token'
)
); );
}); });
@ -164,11 +162,10 @@ describe.concurrent('AIService', () => {
}); });
const aiService = new AIService(gitConfig, cloud); const aiService = new AIService(gitConfig, cloud);
expect(await aiService.buildClient()).toBeInstanceOf(AnthropicAIClient); expect(unwrap(await aiService.buildClient())).toBeInstanceOf(AnthropicAIClient);
}); });
test('When ai provider is Anthropic, When token is bring your own, When an anthropic token is blank. It returns undefined', async () => { test('When ai provider is Anthropic, When token is bring your own, When an anthropic token is blank. It returns undefined', async () => {
const toastErrorSpy = vi.spyOn(toasts, 'error');
const gitConfig = new DummyGitConfigService({ const gitConfig = new DummyGitConfigService({
...defaultGitConfig, ...defaultGitConfig,
[GitAIConfigKey.ModelProvider]: ModelKind.Anthropic, [GitAIConfigKey.ModelProvider]: ModelKind.Anthropic,
@ -177,9 +174,10 @@ describe.concurrent('AIService', () => {
}); });
const aiService = new AIService(gitConfig, cloud); const aiService = new AIService(gitConfig, cloud);
expect(await aiService.buildClient()).toBe(undefined); expect(await aiService.buildClient()).toStrictEqual(
expect(toastErrorSpy).toHaveBeenLastCalledWith( buildFailureFromAny(
'When using Anthropic in a bring your own key configuration, you must provide a valid token' 'When using Anthropic in a bring your own key configuration, you must provide a valid token'
)
); );
}); });
}); });
@ -188,9 +186,13 @@ describe.concurrent('AIService', () => {
test('When buildModel returns undefined, it returns undefined', async () => { test('When buildModel returns undefined, it returns undefined', async () => {
const aiService = buildDefaultAIService(); const aiService = buildDefaultAIService();
vi.spyOn(aiService, 'buildClient').mockReturnValue((async () => undefined)()); vi.spyOn(aiService, 'buildClient').mockReturnValue(
(async () => buildFailureFromAny('Failed to build'))()
);
expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toBe(undefined); expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toStrictEqual(
buildFailureFromAny('Failed to build')
);
}); });
test('When the AI returns a single line commit message, it returns it unchanged', async () => { test('When the AI returns a single line commit message, it returns it unchanged', async () => {
@ -199,10 +201,12 @@ describe.concurrent('AIService', () => {
const clientResponse = 'single line commit'; const clientResponse = 'single line commit';
vi.spyOn(aiService, 'buildClient').mockReturnValue( vi.spyOn(aiService, 'buildClient').mockReturnValue(
(async () => new DummyAIClient(clientResponse))() (async () => ok<AIClient, Error>(new DummyAIClient(clientResponse)))()
); );
expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toBe('single line commit'); expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toStrictEqual(
ok('single line commit')
);
}); });
test('When the AI returns a title and body that is split by a single new line, it replaces it with two', async () => { test('When the AI returns a title and body that is split by a single new line, it replaces it with two', async () => {
@ -211,10 +215,12 @@ describe.concurrent('AIService', () => {
const clientResponse = 'one\nnew line'; const clientResponse = 'one\nnew line';
vi.spyOn(aiService, 'buildClient').mockReturnValue( vi.spyOn(aiService, 'buildClient').mockReturnValue(
(async () => new DummyAIClient(clientResponse))() (async () => ok<AIClient, Error>(new DummyAIClient(clientResponse)))()
); );
expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toBe('one\n\nnew line'); expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toStrictEqual(
ok('one\n\nnew line')
);
}); });
test('When the commit is in brief mode, When the AI returns a title and body, it takes just the title', async () => { test('When the commit is in brief mode, When the AI returns a title and body, it takes just the title', async () => {
@ -223,12 +229,12 @@ describe.concurrent('AIService', () => {
const clientResponse = 'one\nnew line'; const clientResponse = 'one\nnew line';
vi.spyOn(aiService, 'buildClient').mockReturnValue( vi.spyOn(aiService, 'buildClient').mockReturnValue(
(async () => new DummyAIClient(clientResponse))() (async () => ok<AIClient, Error>(new DummyAIClient(clientResponse)))()
); );
expect(await aiService.summarizeCommit({ hunks: exampleHunks, useBriefStyle: true })).toBe( expect(
'one' await aiService.summarizeCommit({ hunks: exampleHunks, useBriefStyle: true })
); ).toStrictEqual(ok('one'));
}); });
}); });
@ -236,9 +242,13 @@ describe.concurrent('AIService', () => {
test('When buildModel returns undefined, it returns undefined', async () => { test('When buildModel returns undefined, it returns undefined', async () => {
const aiService = buildDefaultAIService(); const aiService = buildDefaultAIService();
vi.spyOn(aiService, 'buildClient').mockReturnValue((async () => undefined)()); vi.spyOn(aiService, 'buildClient').mockReturnValue(
(async () => buildFailureFromAny('Failed to build client'))()
);
expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe(undefined); expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toStrictEqual(
buildFailureFromAny('Failed to build client')
);
}); });
test('When the AI client returns a string with spaces, it replaces them with hypens', async () => { test('When the AI client returns a string with spaces, it replaces them with hypens', async () => {
@ -247,10 +257,12 @@ describe.concurrent('AIService', () => {
const clientResponse = 'with spaces included'; const clientResponse = 'with spaces included';
vi.spyOn(aiService, 'buildClient').mockReturnValue( vi.spyOn(aiService, 'buildClient').mockReturnValue(
(async () => new DummyAIClient(clientResponse))() (async () => ok<AIClient, Error>(new DummyAIClient(clientResponse)))()
); );
expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe('with-spaces-included'); expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toStrictEqual(
ok('with-spaces-included')
);
}); });
test('When the AI client returns multiple lines, it replaces them with hypens', async () => { test('When the AI client returns multiple lines, it replaces them with hypens', async () => {
@ -259,11 +271,11 @@ describe.concurrent('AIService', () => {
const clientResponse = 'with\nnew\nlines\nincluded'; const clientResponse = 'with\nnew\nlines\nincluded';
vi.spyOn(aiService, 'buildClient').mockReturnValue( vi.spyOn(aiService, 'buildClient').mockReturnValue(
(async () => new DummyAIClient(clientResponse))() (async () => ok<AIClient, Error>(new DummyAIClient(clientResponse)))()
); );
expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe( expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toStrictEqual(
'with-new-lines-included' ok('with-new-lines-included')
); );
}); });
@ -273,11 +285,11 @@ describe.concurrent('AIService', () => {
const clientResponse = 'with\nnew lines\nincluded'; const clientResponse = 'with\nnew lines\nincluded';
vi.spyOn(aiService, 'buildClient').mockReturnValue( vi.spyOn(aiService, 'buildClient').mockReturnValue(
(async () => new DummyAIClient(clientResponse))() (async () => ok<AIClient, Error>(new DummyAIClient(clientResponse)))()
); );
expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe( expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toStrictEqual(
'with-new-lines-included' ok('with-new-lines-included')
); );
}); });
}); });

View File

@ -14,8 +14,8 @@ import {
MessageRole, MessageRole,
type Prompt type Prompt
} from '$lib/ai/types'; } from '$lib/ai/types';
import { buildFailureFromAny, isFailure, ok, type Result } from '$lib/result';
import { splitMessage } from '$lib/utils/commitMessage'; import { splitMessage } from '$lib/utils/commitMessage';
import * as toasts from '$lib/utils/toasts';
import OpenAI from 'openai'; import OpenAI from 'openai';
import type { GitConfigService } from '$lib/backend/gitConfigService'; import type { GitConfigService } from '$lib/backend/gitConfigService';
import type { HttpClient } from '$lib/backend/httpClient'; import type { HttpClient } from '$lib/backend/httpClient';
@ -189,21 +189,22 @@ export class AIService {
// This optionally returns a summarizer. There are a few conditions for how this may occur // This optionally returns a summarizer. There are a few conditions for how this may occur
// Firstly, if the user has opted to use the GB API and isn't logged in, it will return undefined // Firstly, if the user has opted to use the GB API and isn't logged in, it will return undefined
// Secondly, if the user has opted to bring their own key but hasn't provided one, it will return undefined // Secondly, if the user has opted to bring their own key but hasn't provided one, it will return undefined
async buildClient(userToken?: string): Promise<undefined | AIClient> { async buildClient(userToken?: string): Promise<Result<AIClient, Error>> {
const modelKind = await this.getModelKind(); const modelKind = await this.getModelKind();
if (await this.usingGitButlerAPI()) { if (await this.usingGitButlerAPI()) {
if (!userToken) { if (!userToken) {
toasts.error("When using GitButler's API to summarize code, you must be logged in"); return buildFailureFromAny(
return; "When using GitButler's API to summarize code, you must be logged in"
);
} }
return new ButlerAIClient(this.cloud, userToken, modelKind); return ok(new ButlerAIClient(this.cloud, userToken, modelKind));
} }
if (modelKind === ModelKind.Ollama) { if (modelKind === ModelKind.Ollama) {
const ollamaEndpoint = await this.getOllamaEndpoint(); const ollamaEndpoint = await this.getOllamaEndpoint();
const ollamaModelName = await this.getOllamaModelName(); const ollamaModelName = await this.getOllamaModelName();
return new OllamaClient(ollamaEndpoint, ollamaModelName); return ok(new OllamaClient(ollamaEndpoint, ollamaModelName));
} }
if (modelKind === ModelKind.OpenAI) { if (modelKind === ModelKind.OpenAI) {
@ -211,14 +212,13 @@ export class AIService {
const openAIKey = await this.getOpenAIKey(); const openAIKey = await this.getOpenAIKey();
if (!openAIKey) { if (!openAIKey) {
toasts.error( return buildFailureFromAny(
'When using OpenAI in a bring your own key configuration, you must provide a valid token' 'When using OpenAI in a bring your own key configuration, you must provide a valid token'
); );
return;
} }
const openAI = new OpenAI({ apiKey: openAIKey, dangerouslyAllowBrowser: true }); const openAI = new OpenAI({ apiKey: openAIKey, dangerouslyAllowBrowser: true });
return new OpenAIClient(openAIModelName, openAI); return ok(new OpenAIClient(openAIModelName, openAI));
} }
if (modelKind === ModelKind.Anthropic) { if (modelKind === ModelKind.Anthropic) {
@ -226,14 +226,15 @@ export class AIService {
const anthropicKey = await this.getAnthropicKey(); const anthropicKey = await this.getAnthropicKey();
if (!anthropicKey) { if (!anthropicKey) {
toasts.error( return buildFailureFromAny(
'When using Anthropic in a bring your own key configuration, you must provide a valid token' 'When using Anthropic in a bring your own key configuration, you must provide a valid token'
); );
return;
} }
return new AnthropicAIClient(anthropicKey, anthropicModelName); return ok(new AnthropicAIClient(anthropicKey, anthropicModelName));
} }
return buildFailureFromAny('Failed to build ai client');
} }
async summarizeCommit({ async summarizeCommit({
@ -242,9 +243,10 @@ export class AIService {
useBriefStyle = false, useBriefStyle = false,
commitTemplate, commitTemplate,
userToken userToken
}: SummarizeCommitOpts) { }: SummarizeCommitOpts): Promise<Result<string, Error>> {
const aiClient = await this.buildClient(userToken); const aiClientResult = await this.buildClient(userToken);
if (!aiClient) return; if (isFailure(aiClientResult)) return aiClientResult;
const aiClient = aiClientResult.value;
const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI(); const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI();
const defaultedCommitTemplate = commitTemplate || aiClient.defaultCommitTemplate; const defaultedCommitTemplate = commitTemplate || aiClient.defaultCommitTemplate;
@ -272,19 +274,26 @@ export class AIService {
}; };
}); });
let message = await aiClient.evaluate(prompt); const messageResult = await aiClient.evaluate(prompt);
if (isFailure(messageResult)) return messageResult;
let message = messageResult.value;
if (useBriefStyle) { if (useBriefStyle) {
message = message.split('\n')[0]; message = message.split('\n')[0];
} }
const { title, description } = splitMessage(message); const { title, description } = splitMessage(message);
return description ? `${title}\n\n${description}` : title; return ok(description ? `${title}\n\n${description}` : title);
} }
async summarizeBranch({ hunks, branchTemplate, userToken = undefined }: SummarizeBranchOpts) { async summarizeBranch({
const aiClient = await this.buildClient(userToken); hunks,
if (!aiClient) return; branchTemplate,
userToken = undefined
}: SummarizeBranchOpts): Promise<Result<string, Error>> {
const aiClientResult = await this.buildClient(userToken);
if (isFailure(aiClientResult)) return aiClientResult;
const aiClient = aiClientResult.value;
const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI(); const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI();
const defaultedBranchTemplate = branchTemplate || aiClient.defaultBranchTemplate; const defaultedBranchTemplate = branchTemplate || aiClient.defaultBranchTemplate;
@ -299,7 +308,10 @@ export class AIService {
}; };
}); });
const message = await aiClient.evaluate(prompt); const messageResult = await aiClient.evaluate(prompt);
return message.replaceAll(' ', '-').replaceAll('\n', '-'); if (isFailure(messageResult)) return messageResult;
const message = messageResult.value;
return ok(message.replaceAll(' ', '-').replaceAll('\n', '-'));
} }
} }

View File

@ -1,4 +1,5 @@
import type { Persisted } from '$lib/persisted/persisted'; import type { Persisted } from '$lib/persisted/persisted';
import type { Result } from '$lib/result';
export enum ModelKind { export enum ModelKind {
OpenAI = 'openai', OpenAI = 'openai',
@ -33,7 +34,7 @@ export interface PromptMessage {
export type Prompt = PromptMessage[]; export type Prompt = PromptMessage[];
export interface AIClient { export interface AIClient {
evaluate(prompt: Prompt): Promise<string>; evaluate(prompt: Prompt): Promise<Result<string, Error>>;
defaultBranchTemplate: Prompt; defaultBranchTemplate: Prompt;
defaultCommitTemplate: Prompt; defaultCommitTemplate: Prompt;

View File

@ -1,3 +1,4 @@
import { wrapAsync } from '$lib/result';
import { PUBLIC_API_BASE_URL } from '$env/static/public'; import { PUBLIC_API_BASE_URL } from '$env/static/public';
export const API_URL = new URL('/api/', PUBLIC_API_BASE_URL); export const API_URL = new URL('/api/', PUBLIC_API_BASE_URL);
@ -47,21 +48,41 @@ export class HttpClient {
return await this.request<T>(path, { ...opts, method: 'GET' }); return await this.request<T>(path, { ...opts, method: 'GET' });
} }
async getSafe<T>(path: string, opts?: Omit<RequestOptions, 'body'>) {
return await wrapAsync<T, Error>(async () => await this.get<T>(path, opts));
}
async post<T>(path: string, opts?: RequestOptions) { async post<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'POST' }); return await this.request<T>(path, { ...opts, method: 'POST' });
} }
async postSafe<T>(path: string, opts?: RequestOptions) {
return await wrapAsync<T, Error>(async () => await this.post<T>(path, opts));
}
async put<T>(path: string, opts?: RequestOptions) { async put<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'PUT' }); return await this.request<T>(path, { ...opts, method: 'PUT' });
} }
async putSafe<T>(path: string, opts?: RequestOptions) {
return await wrapAsync<T, Error>(async () => await this.put<T>(path, opts));
}
async patch<T>(path: string, opts?: RequestOptions) { async patch<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'PATCH' }); return await this.request<T>(path, { ...opts, method: 'PATCH' });
} }
async patchSafe<T>(path: string, opts?: RequestOptions) {
return await wrapAsync<T, Error>(async () => await this.patch<T>(path, opts));
}
async delete<T>(path: string, opts?: RequestOptions) { async delete<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'DELETE' }); return await this.request<T>(path, { ...opts, method: 'DELETE' });
} }
async deleteSafe<T>(path: string, opts?: RequestOptions) {
return await wrapAsync<T, Error>(async () => await this.delete<T>(path, opts));
}
} }
function getApiUrl(path: string) { function getApiUrl(path: string) {

View File

@ -17,6 +17,7 @@
import BranchFiles from '$lib/file/BranchFiles.svelte'; import BranchFiles from '$lib/file/BranchFiles.svelte';
import { showError } from '$lib/notifications/toasts'; import { showError } from '$lib/notifications/toasts';
import { persisted } from '$lib/persisted/persisted'; import { persisted } from '$lib/persisted/persisted';
import { isFailure } from '$lib/result';
import { SETTINGS, type Settings } from '$lib/settings/userSettings'; import { SETTINGS, type Settings } from '$lib/settings/userSettings';
import Resizer from '$lib/shared/Resizer.svelte'; import Resizer from '$lib/shared/Resizer.svelte';
import { User } from '$lib/stores/user'; import { User } from '$lib/stores/user';
@ -64,21 +65,25 @@
const hunks = branch.files.flatMap((f) => f.hunks); const hunks = branch.files.flatMap((f) => f.hunks);
try { const prompt = promptService.selectedBranchPrompt(project.id);
const prompt = promptService.selectedBranchPrompt(project.id); const messageResult = await aiService.summarizeBranch({
const message = await aiService.summarizeBranch({ hunks,
hunks, userToken: $user?.access_token,
userToken: $user?.access_token, branchTemplate: prompt
branchTemplate: prompt });
});
if (message && message !== branch.name) { if (isFailure(messageResult)) {
branch.name = message; console.error(messageResult.failure);
branchController.updateBranchName(branch.id, branch.name); showError('Failed to generate branch name', messageResult.failure);
}
} catch (e) { return;
console.error(e); }
showError('Failed to generate branch name', e);
const message = messageResult.value;
if (message && message !== branch.name) {
branch.name = message;
branchController.updateBranchName(branch.id, branch.name);
} }
} }

View File

@ -11,6 +11,7 @@
projectCommitGenerationUseEmojis projectCommitGenerationUseEmojis
} from '$lib/config/config'; } from '$lib/config/config';
import { showError } from '$lib/notifications/toasts'; import { showError } from '$lib/notifications/toasts';
import { isFailure } from '$lib/result';
import Checkbox from '$lib/shared/Checkbox.svelte'; import Checkbox from '$lib/shared/Checkbox.svelte';
import DropDownButton from '$lib/shared/DropDownButton.svelte'; import DropDownButton from '$lib/shared/DropDownButton.svelte';
import Icon from '$lib/shared/Icon.svelte'; import Icon from '$lib/shared/Icon.svelte';
@ -75,27 +76,35 @@
} }
aiLoading = true; aiLoading = true;
try {
const prompt = promptService.selectedCommitPrompt(project.id);
console.log(prompt);
const generatedMessage = await aiService.summarizeCommit({
hunks,
useEmojiStyle: $commitGenerationUseEmojis,
useBriefStyle: $commitGenerationExtraConcise,
userToken: $user?.access_token,
commitTemplate: prompt
});
if (generatedMessage) { const prompt = promptService.selectedCommitPrompt(project.id);
commitMessage = generatedMessage;
} else { const generatedMessageResult = await aiService.summarizeCommit({
throw new Error('Prompt generated no response'); hunks,
} useEmojiStyle: $commitGenerationUseEmojis,
} catch (e: any) { useBriefStyle: $commitGenerationExtraConcise,
showError('Failed to generate commit message', e); userToken: $user?.access_token,
} finally { commitTemplate: prompt
});
if (isFailure(generatedMessageResult)) {
showError('Failed to generate commit message', generatedMessageResult.failure);
aiLoading = false; aiLoading = false;
return;
} }
const generatedMessage = generatedMessageResult.value;
if (generatedMessage) {
commitMessage = generatedMessage;
} else {
const errorMessage = 'Prompt generated no response';
showError(errorMessage, undefined);
aiLoading = false;
return;
}
aiLoading = false;
} }
onMount(async () => { onMount(async () => {

99
app/src/lib/result.ts Normal file
View File

@ -0,0 +1,99 @@
export class Panic extends Error {}
export type OkVariant<Ok> = {
ok: true;
value: Ok;
};
export type FailureVariant<Err> = {
ok: false;
failure: Err;
};
export type Result<Ok, Err> = OkVariant<Ok> | FailureVariant<Err>;
export function isOk<Ok, Err>(
subject: OkVariant<Ok> | FailureVariant<Err>
): subject is OkVariant<Ok> {
return subject.ok;
}
export function isFailure<Ok, Err>(
subject: OkVariant<Ok> | FailureVariant<Err>
): subject is FailureVariant<Err> {
return !subject.ok;
}
export function ok<Ok, Err>(value: Ok): Result<Ok, Err> {
return { ok: true, value };
}
export function failure<Ok, Err>(value: Err): Result<Ok, Err> {
return { ok: false, failure: value };
}
export function buildFailureFromAny<Ok>(value: any): Result<Ok, Error> {
if (value instanceof Error) {
return failure(value);
} else {
return failure(new Error(String(value)));
}
}
export function wrap<Ok, Err>(subject: () => Ok): Result<Ok, Err> {
try {
return ok(subject());
} catch (e) {
return failure(e as Err);
}
}
export async function wrapAsync<Ok, Err>(subject: () => Promise<Ok>): Promise<Result<Ok, Err>> {
try {
return ok(await subject());
} catch (e) {
return failure(e as Err);
}
}
export function unwrap<Ok, Err>(subject: Result<Ok, Err>): Ok {
if (isOk(subject)) {
return subject.value;
} else {
if (subject.failure instanceof Error) {
throw subject.failure;
} else {
throw new Panic(String(subject.failure));
}
}
}
export function unwrapOr<Ok, Err, Or>(subject: Result<Ok, Err>, or: Or): Ok | Or {
if (isOk(subject)) {
return subject.value;
} else {
return or;
}
}
export function map<Ok, Err, NewOk>(
subject: Result<Ok, Err>,
transformation: (ok: Ok) => NewOk
): Result<NewOk, Err> {
if (isOk(subject)) {
return ok(transformation(subject.value));
} else {
return subject;
}
}
export function andThen<Ok, Err, NewOk>(
subject: Result<Ok, Err>,
transformation: (ok: Ok) => Result<NewOk, Err>
): Result<NewOk, Err> {
if (isOk(subject)) {
return transformation(subject.value);
} else {
return subject;
}
}