Merge pull request #5372 from gitbutlerapp/e-branch-2

ai: Stream the Butler AI messages
This commit is contained in:
Esteban Vega 2024-11-01 14:02:55 +01:00 committed by GitHub
commit d78d84a887
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 66 additions and 16 deletions

View File

@@ -4,8 +4,9 @@ import {
SHORT_DEFAULT_COMMIT_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE,
SHORT_DEFAULT_PR_TEMPLATE SHORT_DEFAULT_PR_TEMPLATE
} from '$lib/ai/prompts'; } from '$lib/ai/prompts';
import { ModelKind, type AIClient, type Prompt } from '$lib/ai/types'; import { ModelKind, type AIClient, type AIEvalOptions, type Prompt } from '$lib/ai/types';
import { map, wrapAsync, type Result } from '$lib/result'; import { andThenAsync, ok, wrapAsync, type Result } from '$lib/result';
import { stringStreamGenerator } from '$lib/utils/promise';
import type { HttpClient } from '@gitbutler/shared/httpClient'; import type { HttpClient } from '@gitbutler/shared/httpClient';
function splitPromptMessagesIfNecessary( function splitPromptMessagesIfNecessary(
@@ -33,20 +34,33 @@ export class ButlerAIClient implements AIClient {
private modelKind: ModelKind private modelKind: ModelKind
) {} ) {}
async evaluate(prompt: Prompt): Promise<Result<string, Error>> { async evaluate(prompt: Prompt, options?: AIEvalOptions): Promise<Result<string, Error>> {
const [messages, system] = splitPromptMessagesIfNecessary(this.modelKind, prompt); const [messages, system] = splitPromptMessagesIfNecessary(this.modelKind, prompt);
const response = await wrapAsync<{ message: string }, Error>( const response = await wrapAsync<Response, Error>(
async () => async () =>
await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', { await this.cloud.postRaw('ai/stream', {
body: { body: {
messages, messages,
system, system,
max_tokens: 400, max_tokens: 3600,
model_kind: this.modelKind model_kind: this.modelKind
} }
}) })
); );
return map(response, ({ message }) => message); return await andThenAsync(response, async (r) => {
const reader = r.body?.getReader();
if (!reader) {
return ok('');
}
const buffer: string[] = [];
for await (const chunk of stringStreamGenerator(reader)) {
options?.onToken?.(chunk);
buffer.push(chunk);
}
return ok(buffer.join(''));
});
} }
} }

View File

@@ -74,12 +74,21 @@
const prompt = promptService.selectedCommitPrompt(project.id); const prompt = promptService.selectedCommitPrompt(project.id);
let firstToken = true;
const generatedMessageResult = await aiService.summarizeCommit({ const generatedMessageResult = await aiService.summarizeCommit({
hunks, hunks,
useEmojiStyle: $commitGenerationUseEmojis, useEmojiStyle: $commitGenerationUseEmojis,
useBriefStyle: $commitGenerationExtraConcise, useBriefStyle: $commitGenerationExtraConcise,
commitTemplate: prompt, commitTemplate: prompt,
branchName: $branch.name branchName: $branch.name,
onToken: (t) => {
if (firstToken) {
commitMessage = '';
firstToken = false;
}
commitMessage += t;
}
}); });
if (isFailure(generatedMessageResult)) { if (isFailure(generatedMessageResult)) {

View File

@@ -253,7 +253,7 @@
directive: aiDescriptionDirective, directive: aiDescriptionDirective,
commitMessages: commits.map((c) => c.description), commitMessages: commits.map((c) => c.description),
prBodyTemplate: templateBody, prBodyTemplate: templateBody,
onToken: async (t) => { onToken: (t) => {
if (firstToken) { if (firstToken) {
inputBody = ''; inputBody = '';
firstToken = false; firstToken = false;

View File

@@ -0,0 +1,15 @@
/**
 * Yields the contents of a byte stream as decoded UTF-8 string chunks.
 *
 * A single `TextDecoder` is reused across reads with `{ stream: true }` so
 * that a multi-byte UTF-8 sequence split across chunk boundaries is decoded
 * correctly (the previous per-chunk `new TextDecoder().decode(value)` would
 * emit U+FFFD replacement characters at every split). On end-of-stream the
 * decoder is flushed so a trailing buffered sequence is not dropped.
 *
 * The reader's lock is always released, even if the consumer stops iterating
 * early or an error is thrown.
 *
 * @param reader - reader over the raw byte stream (e.g. `response.body.getReader()`)
 * @returns async generator producing decoded string chunks
 */
export async function* stringStreamGenerator(
	reader: ReadableStreamDefaultReader<Uint8Array>
): AsyncGenerator<string, void, void> {
	const decoder = new TextDecoder();
	try {
		while (true) {
			const { done, value } = await reader.read();
			if (done) {
				// Flush any partial multi-byte sequence still buffered in the decoder.
				const tail = decoder.decode();
				if (tail) {
					yield tail;
				}
				break;
			}
			yield decoder.decode(value, { stream: true });
		}
	} finally {
		reader.releaseLock();
	}
}

View File

@@ -35,10 +35,10 @@ export class HttpClient {
return new URL(path, this.apiUrl); return new URL(path, this.apiUrl);
} }
private async request<T>( private async request(
path: string, path: string,
opts: RequestOptions & { method: RequestMethod } opts: RequestOptions & { method: RequestMethod }
): Promise<T> { ): Promise<Response> {
const butlerHeaders = new Headers(DEFAULT_HEADERS); const butlerHeaders = new Headers(DEFAULT_HEADERS);
if (opts.headers) { if (opts.headers) {
@@ -60,27 +60,39 @@ export class HttpClient {
body: formatBody(opts.body) body: formatBody(opts.body)
}); });
return response;
}
private async requestJson<T>(
path: string,
opts: RequestOptions & { method: RequestMethod }
): Promise<T> {
const response = await this.request(path, opts);
return await parseResponseJSON(response); return await parseResponseJSON(response);
} }
async get<T>(path: string, opts?: Omit<RequestOptions, 'body'>) { async get<T>(path: string, opts?: Omit<RequestOptions, 'body'>) {
return await this.request<T>(path, { ...opts, method: 'GET' }); return await this.requestJson<T>(path, { ...opts, method: 'GET' });
} }
async post<T>(path: string, opts?: RequestOptions) { async post<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'POST' }); return await this.requestJson<T>(path, { ...opts, method: 'POST' });
} }
async put<T>(path: string, opts?: RequestOptions) { async put<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'PUT' }); return await this.requestJson<T>(path, { ...opts, method: 'PUT' });
} }
async patch<T>(path: string, opts?: RequestOptions) { async patch<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'PATCH' }); return await this.requestJson<T>(path, { ...opts, method: 'PATCH' });
} }
async delete<T>(path: string, opts?: RequestOptions) { async delete<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'DELETE' }); return await this.requestJson<T>(path, { ...opts, method: 'DELETE' });
}
/**
 * Issues a POST and returns the raw `Response` without JSON parsing,
 * unlike `post<T>`, which runs the result through `parseResponseJSON`.
 * Presumably added so callers can consume a streaming response body
 * (e.g. `response.body.getReader()`) — TODO confirm against callers.
 *
 * @param path - request path, resolved against the client's API URL
 * @param opts - optional headers/body; the method is forced to POST
 */
async postRaw(path: string, opts?: RequestOptions) {
return await this.request(path, { ...opts, method: 'POST' });
}
} }