Merge pull request #5372 from gitbutlerapp/e-branch-2

ai: Stream the Butler AI messages
This commit is contained in:
Esteban Vega 2024-11-01 14:02:55 +01:00 committed by GitHub
commit d78d84a887
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 66 additions and 16 deletions

View File

@ -4,8 +4,9 @@ import {
SHORT_DEFAULT_COMMIT_TEMPLATE,
SHORT_DEFAULT_PR_TEMPLATE
} from '$lib/ai/prompts';
import { ModelKind, type AIClient, type Prompt } from '$lib/ai/types';
import { map, wrapAsync, type Result } from '$lib/result';
import { ModelKind, type AIClient, type AIEvalOptions, type Prompt } from '$lib/ai/types';
import { andThenAsync, ok, wrapAsync, type Result } from '$lib/result';
import { stringStreamGenerator } from '$lib/utils/promise';
import type { HttpClient } from '@gitbutler/shared/httpClient';
function splitPromptMessagesIfNecessary(
@ -33,20 +34,33 @@ export class ButlerAIClient implements AIClient {
private modelKind: ModelKind
) {}
async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
async evaluate(prompt: Prompt, options?: AIEvalOptions): Promise<Result<string, Error>> {
const [messages, system] = splitPromptMessagesIfNecessary(this.modelKind, prompt);
const response = await wrapAsync<{ message: string }, Error>(
const response = await wrapAsync<Response, Error>(
async () =>
await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', {
await this.cloud.postRaw('ai/stream', {
body: {
messages,
system,
max_tokens: 400,
max_tokens: 3600,
model_kind: this.modelKind
}
})
);
return map(response, ({ message }) => message);
return await andThenAsync(response, async (r) => {
const reader = r.body?.getReader();
if (!reader) {
return ok('');
}
const buffer: string[] = [];
for await (const chunk of stringStreamGenerator(reader)) {
options?.onToken?.(chunk);
buffer.push(chunk);
}
return ok(buffer.join(''));
});
}
}

View File

@ -74,12 +74,21 @@
const prompt = promptService.selectedCommitPrompt(project.id);
let firstToken = true;
const generatedMessageResult = await aiService.summarizeCommit({
hunks,
useEmojiStyle: $commitGenerationUseEmojis,
useBriefStyle: $commitGenerationExtraConcise,
commitTemplate: prompt,
branchName: $branch.name
branchName: $branch.name,
onToken: (t) => {
if (firstToken) {
commitMessage = '';
firstToken = false;
}
commitMessage += t;
}
});
if (isFailure(generatedMessageResult)) {

View File

@ -253,7 +253,7 @@
directive: aiDescriptionDirective,
commitMessages: commits.map((c) => c.description),
prBodyTemplate: templateBody,
onToken: async (t) => {
onToken: (t) => {
if (firstToken) {
inputBody = '';
firstToken = false;

View File

@ -0,0 +1,15 @@
export async function* stringStreamGenerator(
reader: ReadableStreamDefaultReader<Uint8Array>
): AsyncGenerator<string, void, void> {
try {
while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
yield new TextDecoder().decode(value);
}
} finally {
reader.releaseLock();
}
}

View File

@ -35,10 +35,10 @@ export class HttpClient {
return new URL(path, this.apiUrl);
}
private async request<T>(
private async request(
path: string,
opts: RequestOptions & { method: RequestMethod }
): Promise<T> {
): Promise<Response> {
const butlerHeaders = new Headers(DEFAULT_HEADERS);
if (opts.headers) {
@ -60,27 +60,39 @@ export class HttpClient {
body: formatBody(opts.body)
});
return response;
}
private async requestJson<T>(
path: string,
opts: RequestOptions & { method: RequestMethod }
): Promise<T> {
const response = await this.request(path, opts);
return await parseResponseJSON(response);
}
async get<T>(path: string, opts?: Omit<RequestOptions, 'body'>) {
return await this.request<T>(path, { ...opts, method: 'GET' });
return await this.requestJson<T>(path, { ...opts, method: 'GET' });
}
async post<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'POST' });
return await this.requestJson<T>(path, { ...opts, method: 'POST' });
}
async put<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'PUT' });
return await this.requestJson<T>(path, { ...opts, method: 'PUT' });
}
async patch<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'PATCH' });
return await this.requestJson<T>(path, { ...opts, method: 'PATCH' });
}
async delete<T>(path: string, opts?: RequestOptions) {
return await this.request<T>(path, { ...opts, method: 'DELETE' });
return await this.requestJson<T>(path, { ...opts, method: 'DELETE' });
}
async postRaw(path: string, opts?: RequestOptions) {
return await this.request(path, { ...opts, method: 'POST' });
}
}