Pass in hunk array rather than diff

This commit is contained in:
Caleb Owens 2024-03-20 22:04:46 +00:00
parent a4a2a4fc67
commit bffcdbdae4
4 changed files with 99 additions and 34 deletions

View File

@ -8,9 +8,12 @@ import {
GitAIConfigKey,
KeyOption,
ModelKind,
OpenAIModelName
OpenAIModelName,
buildDiff
} from '$lib/backend/aiService';
import * as toasts from '$lib/utils/toasts';
import { Hunk } from '$lib/vbranches/types';
import { plainToInstance } from 'class-transformer';
import { expect, test, describe, vi } from 'vitest';
import type { AIClient } from '$lib/backend/aiClient';
import type { GitConfigService } from '$lib/backend/gitConfigService';
@ -52,7 +55,7 @@ class DummyAIClient implements AIClient {
}
}
const examplePatch = `
const diff1 = `
@@ -52,7 +52,8 @@
export enum AnthropicModelName {
@ -65,6 +68,38 @@ const examplePatch = `
export const AI_SERVICE_CONTEXT = Symbol();
`;
const hunk1 = plainToInstance(Hunk, {
id: 'asdf',
diff: diff1,
modifiedAt: new Date().toISOString(),
filePath: 'foo/bar/baz.ts',
locked: false,
lockedTo: undefined,
changeType: 'added'
});
const diff2 = `
@@ -52,7 +52,8 @@
}
async function commit() {
console.log('quack quack goes the dog');
+ const message = concatMessage(title, description);
isCommitting = true;
try {
`;
const hunk2 = plainToInstance(Hunk, {
id: 'asdf',
diff: diff2,
modifiedAt: new Date().toISOString(),
filePath: 'random.ts',
locked: false,
lockedTo: undefined,
changeType: 'added'
});
const exampleHunks = [hunk1, hunk2];
function buildDefaultAIService() {
const gitConfig = new DummyGitConfigService(structuredClone(defaultGitConfig));
return new AIService(gitConfig, cloud);
@ -149,7 +184,7 @@ describe.concurrent('AIService', () => {
vi.spyOn(aiService, 'buildClient').mockReturnValue((async () => undefined)());
expect(await aiService.summarizeCommit({ diff: examplePatch })).toBe(undefined);
expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toBe(undefined);
});
test('When the AI returns a single line commit message, it returns it unchanged', async () => {
@ -161,7 +196,7 @@ describe.concurrent('AIService', () => {
(async () => new DummyAIClient(clientResponse))()
);
expect(await aiService.summarizeCommit({ diff: examplePatch })).toBe('single line commit');
expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toBe('single line commit');
});
test('When the AI returns a title and body that is split by a single new line, it replaces it with two', async () => {
@ -173,7 +208,7 @@ describe.concurrent('AIService', () => {
(async () => new DummyAIClient(clientResponse))()
);
expect(await aiService.summarizeCommit({ diff: examplePatch })).toBe('one\n\nnew line');
expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toBe('one\n\nnew line');
});
test('When the commit is in brief mode, When the AI returns a title and body, it takes just the title', async () => {
@ -185,7 +220,7 @@ describe.concurrent('AIService', () => {
(async () => new DummyAIClient(clientResponse))()
);
expect(await aiService.summarizeCommit({ diff: examplePatch, useBriefStyle: true })).toBe(
expect(await aiService.summarizeCommit({ hunks: exampleHunks, useBriefStyle: true })).toBe(
'one'
);
});
@ -197,7 +232,7 @@ describe.concurrent('AIService', () => {
vi.spyOn(aiService, 'buildClient').mockReturnValue((async () => undefined)());
expect(await aiService.summarizeBranch({ diff: examplePatch })).toBe(undefined);
expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe(undefined);
});
test('When the AI client returns a string with spaces, it replaces them with hypens', async () => {
@ -209,7 +244,7 @@ describe.concurrent('AIService', () => {
(async () => new DummyAIClient(clientResponse))()
);
expect(await aiService.summarizeBranch({ diff: examplePatch })).toBe('with-spaces-included');
expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe('with-spaces-included');
});
test('When the AI client returns multiple lines, it replaces them with hypens', async () => {
@ -221,7 +256,7 @@ describe.concurrent('AIService', () => {
(async () => new DummyAIClient(clientResponse))()
);
expect(await aiService.summarizeBranch({ diff: examplePatch })).toBe(
expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe(
'with-new-lines-included'
);
});
@ -235,9 +270,32 @@ describe.concurrent('AIService', () => {
(async () => new DummyAIClient(clientResponse))()
);
expect(await aiService.summarizeBranch({ diff: examplePatch })).toBe(
expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe(
'with-new-lines-included'
);
});
});
});
describe.concurrent('buildDiff', () => {
test('When provided one hunk, it returns the formatted diff', () => {
const expectedOutput = `${hunk1.filePath} - ${hunk1.diff}`;
expect(buildDiff([hunk1], 10000)).to.eq(expectedOutput);
});
test('When provided one hunk and its longer than the limit, it returns the truncated formatted diff', () => {
expect(buildDiff([hunk1], 100).length).to.eq(100);
});
test('When provided multiple hunks, it joins them together with newlines', () => {
const expectedOutput1 = `${hunk1.filePath} - ${hunk1.diff}\n${hunk2.filePath} - ${hunk2.diff}`;
const expectedOutput2 = `${hunk2.filePath} - ${hunk2.diff}\n${hunk1.filePath} - ${hunk1.diff}`;
const outputMatchesExpectedValue = [expectedOutput1, expectedOutput2].includes(
buildDiff([hunk1, hunk1], 10000)
);
expect(outputMatchesExpectedValue).toBeTruthy;
});
});

View File

@ -7,8 +7,9 @@ import OpenAI from 'openai';
import type { AIClient } from '$lib/backend/aiClient';
import type { CloudClient } from '$lib/backend/cloud';
import type { GitConfigService } from '$lib/backend/gitConfigService';
import type { Hunk } from '$lib/vbranches/types';
const diffLengthLimit = 20000;
const diffLengthLimit = 5000;
const defaultCommitTemplate = `
Please could you write a commit message for my changes.
@ -71,7 +72,7 @@ export enum GitAIConfigKey {
}
type SummarizeCommitOpts = {
diff: string;
hunks: Hunk[];
useEmojiStyle?: boolean;
useBriefStyle?: boolean;
commitTemplate?: string;
@ -79,11 +80,25 @@ type SummarizeCommitOpts = {
};
type SummarizeBranchOpts = {
diff: string;
hunks: Hunk[];
branchTemplate?: string;
userToken?: string;
};
// Exported for testing only
export function buildDiff(hunks: Hunk[], limit: number) {
return shuffle(hunks.map((h) => `${h.filePath} - ${h.diff}`))
.join('\n')
.slice(0, limit);
}
function shuffle<T>(items: T[]): T[] {
return items
.map((item) => ({ item, value: Math.random() }))
.sort()
.map((item) => item.item);
}
export class AIService {
constructor(
private gitConfig: GitConfigService,
@ -183,7 +198,7 @@ export class AIService {
}
async summarizeCommit({
diff,
hunks,
useEmojiStyle = false,
useBriefStyle = false,
commitTemplate = defaultCommitTemplate,
@ -192,7 +207,7 @@ export class AIService {
const aiClient = await this.buildClient(userToken);
if (!aiClient) return;
let prompt = commitTemplate.replaceAll('%{diff}', diff.slice(0, diffLengthLimit));
let prompt = commitTemplate.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit));
const briefPart = useBriefStyle
? 'The commit message must be only one sentence and as short as possible.'
@ -215,14 +230,14 @@ export class AIService {
}
async summarizeBranch({
diff,
hunks,
branchTemplate = defaultBranchTemplate,
userToken = undefined
}: SummarizeBranchOpts) {
const aiClient = await this.buildClient(userToken);
if (!aiClient) return;
const prompt = branchTemplate.replaceAll('%{diff}', diff.slice(0, diffLengthLimit));
const prompt = branchTemplate.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit));
const message = await aiClient.evaluate(prompt);
return message.replaceAll(' ', '-').replaceAll('\n', '-');
}

View File

@ -83,16 +83,13 @@
async function generateBranchName() {
if (!aiGenEnabled) return;
const diff = branch.files
.map((f) => f.hunks)
.flat()
.map((h) => h.diff)
.flat()
.join('\n')
.slice(0, 5000);
const hunks = branch.files.flatMap((f) => f.hunks);
try {
const message = await aiService.summarizeBranch({ diff, userToken: $user?.access_token });
const message = await aiService.summarizeBranch({
hunks,
userToken: $user?.access_token
});
if (message && message !== branch.name) {
branch.name = message;

View File

@ -90,14 +90,9 @@
}
async function generateCommitMessage(files: LocalFile[]) {
const diff = files
.map((f) => f.hunks.filter((h) => $selectedOwnership.containsHunk(f.id, h.id)))
.flat()
.map((h) => h.diff)
.flat()
.join('\n')
.slice(0, 5000);
const hunks = files.flatMap((f) =>
f.hunks.filter((h) => $selectedOwnership.containsHunk(f.id, h.id))
);
// Branches get their names generated only if there are at least 4 lines of code
// If the change is a 'one-liner', the branch name is either left as "virtual branch"
// or the user has to manually trigger the name generation from the meatball menu
@ -109,7 +104,7 @@
aiLoading = true;
try {
const generatedMessage = await aiService.summarizeCommit({
diff,
hunks,
useEmojiStyle: $commitGenerationUseEmojis,
useBriefStyle: $commitGenerationExtraConcise,
userToken: $user?.access_token