mirror of https://github.com/gitbutlerapp/gitbutler.git
synced 2024-11-27 09:47:34 +03:00

Merge pull request #5372 from gitbutlerapp/e-branch-2

ai: Stream the Butler AI messages

This commit is contained in: commit d78d84a887
@@ -4,8 +4,9 @@ import {
 	SHORT_DEFAULT_COMMIT_TEMPLATE,
 	SHORT_DEFAULT_PR_TEMPLATE
 } from '$lib/ai/prompts';
-import { ModelKind, type AIClient, type Prompt } from '$lib/ai/types';
-import { map, wrapAsync, type Result } from '$lib/result';
+import { ModelKind, type AIClient, type AIEvalOptions, type Prompt } from '$lib/ai/types';
+import { andThenAsync, ok, wrapAsync, type Result } from '$lib/result';
+import { stringStreamGenerator } from '$lib/utils/promise';
 import type { HttpClient } from '@gitbutler/shared/httpClient';
 
 function splitPromptMessagesIfNecessary(
@@ -33,20 +34,33 @@ export class ButlerAIClient implements AIClient {
 		private modelKind: ModelKind
 	) {}
 
-	async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
+	async evaluate(prompt: Prompt, options?: AIEvalOptions): Promise<Result<string, Error>> {
 		const [messages, system] = splitPromptMessagesIfNecessary(this.modelKind, prompt);
-		const response = await wrapAsync<{ message: string }, Error>(
+		const response = await wrapAsync<Response, Error>(
 			async () =>
-				await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', {
+				await this.cloud.postRaw('ai/stream', {
 					body: {
 						messages,
 						system,
-						max_tokens: 400,
+						max_tokens: 3600,
 						model_kind: this.modelKind
 					}
 				})
 		);
 
-		return map(response, ({ message }) => message);
+		return await andThenAsync(response, async (r) => {
+			const reader = r.body?.getReader();
+			if (!reader) {
+				return ok('');
+			}
+
+			const buffer: string[] = [];
+			for await (const chunk of stringStreamGenerator(reader)) {
+				options?.onToken?.(chunk);
+				buffer.push(chunk);
+			}
+
+			return ok(buffer.join(''));
+		});
 	}
 }
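For orientation, a minimal sketch (not part of the diff) of how a caller could consume the streaming evaluate; ButlerAIClient and Prompt are the types from the file above, and the generate wrapper itself is hypothetical:

// Minimal sketch: `client` is assumed to be a ButlerAIClient and `prompt` a Prompt.
async function generate(client: ButlerAIClient, prompt: Prompt) {
	let streamed = '';
	const result = await client.evaluate(prompt, {
		onToken: (token) => {
			streamed += token; // e.g. update a bound textarea as tokens arrive
		}
	});
	// On success, `result` resolves to the same text accumulated in `streamed`.
	return result;
}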
@@ -74,12 +74,21 @@
 
 		const prompt = promptService.selectedCommitPrompt(project.id);
 
+		let firstToken = true;
+
 		const generatedMessageResult = await aiService.summarizeCommit({
 			hunks,
 			useEmojiStyle: $commitGenerationUseEmojis,
 			useBriefStyle: $commitGenerationExtraConcise,
 			commitTemplate: prompt,
-			branchName: $branch.name
+			branchName: $branch.name,
+			onToken: (t) => {
+				if (firstToken) {
+					commitMessage = '';
+					firstToken = false;
+				}
+				commitMessage += t;
+			}
 		});
 
 		if (isFailure(generatedMessageResult)) {
@@ -253,7 +253,7 @@
 				directive: aiDescriptionDirective,
 				commitMessages: commits.map((c) => c.description),
 				prBodyTemplate: templateBody,
-				onToken: async (t) => {
+				onToken: (t) => {
 					if (firstToken) {
 						inputBody = '';
 						firstToken = false;
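Both call sites above wire onToken the same way: the first streamed token clears the previously displayed text, and each later token is appended. A hypothetical standalone rendering of that callback, for illustration only:

// Hypothetical illustration (not in the diff) of the shared onToken pattern.
let firstToken = true;
let boundText = 'stale message from a previous generation';

const onToken = (t: string) => {
	if (firstToken) {
		boundText = ''; // drop the old text once streaming actually starts
		firstToken = false;
	}
	boundText += t;
};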
apps/desktop/src/lib/utils/promise.ts (new file, 15 lines)
@@ -0,0 +1,15 @@
+export async function* stringStreamGenerator(
+	reader: ReadableStreamDefaultReader<Uint8Array>
+): AsyncGenerator<string, void, void> {
+	try {
+		while (true) {
+			const { done, value } = await reader.read();
+			if (done) {
+				break;
+			}
+			yield new TextDecoder().decode(value);
+		}
+	} finally {
+		reader.releaseLock();
+	}
+}
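A small usage sketch for stringStreamGenerator outside the Butler client; the URL below is a placeholder, not one of GitButler's routes:

// Sketch only: stream any fetch Response as decoded text chunks.
async function logStream() {
	const response = await fetch('https://example.com/stream'); // placeholder URL
	const reader = response.body?.getReader();
	if (!reader) return;

	for await (const chunk of stringStreamGenerator(reader)) {
		console.log('chunk:', chunk); // each chunk is one decoded Uint8Array read
	}
	// The generator's finally block releases the reader lock once iteration ends.
}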
@@ -35,10 +35,10 @@ export class HttpClient {
 		return new URL(path, this.apiUrl);
 	}
 
-	private async request<T>(
+	private async request(
 		path: string,
 		opts: RequestOptions & { method: RequestMethod }
-	): Promise<T> {
+	): Promise<Response> {
 		const butlerHeaders = new Headers(DEFAULT_HEADERS);
 
 		if (opts.headers) {
@@ -60,27 +60,39 @@ export class HttpClient {
 			body: formatBody(opts.body)
 		});
 
+		return response;
+	}
+
+	private async requestJson<T>(
+		path: string,
+		opts: RequestOptions & { method: RequestMethod }
+	): Promise<T> {
+		const response = await this.request(path, opts);
 		return await parseResponseJSON(response);
 	}
 
 	async get<T>(path: string, opts?: Omit<RequestOptions, 'body'>) {
-		return await this.request<T>(path, { ...opts, method: 'GET' });
+		return await this.requestJson<T>(path, { ...opts, method: 'GET' });
 	}
 
 	async post<T>(path: string, opts?: RequestOptions) {
-		return await this.request<T>(path, { ...opts, method: 'POST' });
+		return await this.requestJson<T>(path, { ...opts, method: 'POST' });
 	}
 
 	async put<T>(path: string, opts?: RequestOptions) {
-		return await this.request<T>(path, { ...opts, method: 'PUT' });
+		return await this.requestJson<T>(path, { ...opts, method: 'PUT' });
 	}
 
 	async patch<T>(path: string, opts?: RequestOptions) {
-		return await this.request<T>(path, { ...opts, method: 'PATCH' });
+		return await this.requestJson<T>(path, { ...opts, method: 'PATCH' });
 	}
 
 	async delete<T>(path: string, opts?: RequestOptions) {
-		return await this.request<T>(path, { ...opts, method: 'DELETE' });
+		return await this.requestJson<T>(path, { ...opts, method: 'DELETE' });
 	}
+
+	async postRaw(path: string, opts?: RequestOptions) {
+		return await this.request(path, { ...opts, method: 'POST' });
+	}
 }
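A hedged sketch of how the two request paths differ after this refactor; http is an assumed HttpClient instance, the Project type and 'projects.json' path are made up, and only 'ai/stream' comes from the diff:

// Sketch only: `http` is an assumed HttpClient instance.
async function demo(http: HttpClient) {
	// JSON path: post<T> routes through requestJson and resolves to parsed JSON.
	type Project = { id: string; name: string }; // made-up shape
	const project = await http.post<Project>('projects.json', { body: { name: 'demo' } });

	// Raw path: postRaw skips parseResponseJSON and returns the Response itself,
	// so the caller can stream response.body (as ButlerAIClient does with 'ai/stream').
	const response = await http.postRaw('ai/stream', { body: { messages: [] } });
	const reader = response.body?.getReader();
	return { project, reader };
}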