Allow multi-message prompt templates

Caleb Owens 2024-05-15 08:09:37 +01:00
parent a4a05b0c4c
commit 6000ec6ae6
8 changed files with 216 additions and 173 deletions

View File

@@ -1,24 +1,21 @@
import {
MessageRole,
type AIClient,
type AnthropicModelName,
type PromptMessage
} from '$lib/ai/types';
import { SHORT_DEFAULT_COMMIT_TEMPLATE, SHORT_DEFAULT_BRANCH_TEMPLATE } from '$lib/ai/prompts';
import { fetch, Body } from '@tauri-apps/api/http';
import type { AIClient, AnthropicModelName, PromptMessage } from '$lib/ai/types';
type AnthropicAPIResponse = { content: { text: string }[] };
export class AnthropicAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
constructor(
private apiKey: string,
private modelName: AnthropicModelName
) {}
async evaluate(prompt: string) {
const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }];
async evaluate(prompt: PromptMessage[]) {
const body = Body.json({
messages,
messages: prompt,
max_tokens: 1024,
model: this.modelName
});
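
A minimal usage sketch (not part of the commit) of the new signature: callers now hand evaluate a ready-made message list instead of a bare string. Here client is assumed to be an already-constructed AnthropicAIClient (or any other AIClient), and the prompt text is a placeholder.

import { MessageRole, type PromptMessage } from '$lib/ai/types';

// Build a multi-message prompt instead of a single string.
const prompt: PromptMessage[] = [
	{ role: MessageRole.User, content: 'Please write a commit message for this diff:\n...' }
];
const commitMessage = await client.evaluate(prompt);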

View File

@@ -1,19 +1,21 @@
import { MessageRole, type ModelKind, type AIClient, type PromptMessage } from '$lib/ai/types';
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import type { AIClient, ModelKind, PromptMessage } from '$lib/ai/types';
import type { HttpClient } from '$lib/backend/httpClient';
export class ButlerAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
constructor(
private cloud: HttpClient,
private userToken: string,
private modelKind: ModelKind
) {}
async evaluate(prompt: string) {
const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }];
async evaluate(prompt: PromptMessage[]) {
const response = await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', {
body: {
messages,
messages: prompt,
max_tokens: 400,
model_kind: this.modelKind
},
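
Roughly, a multi-message prompt now serializes straight into the request body. The payload below is an illustration inferred from this hunk, not documented API output; the lowercase role strings assume that is how MessageRole's enum values serialize.

const prompt: PromptMessage[] = [
	{ role: MessageRole.User, content: 'Example request' },
	{ role: MessageRole.Assistant, content: 'Example reply' },
	{ role: MessageRole.User, content: 'Real request' }
];
// POSTs to evaluate_prompt/predict.json with a body along the lines of:
// { "messages": [ { "role": "user", ... }, { "role": "assistant", ... }, { "role": "user", ... } ],
//   "max_tokens": 400, "model_kind": <this.modelKind> }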

View File

@@ -1,59 +1,10 @@
import { LONG_DEFAULT_BRANCH_TEMPLATE, LONG_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { MessageRole, type PromptMessage, type AIClient } from '$lib/ai/types';
import { isNonEmptyObject } from '$lib/utils/typeguards';
export const DEFAULT_OLLAMA_ENDPOINT = 'http://127.0.0.1:11434';
export const DEFAULT_OLLAMA_MODEL_NAME = 'llama3';
const PROMT_EXAMPLE_COMMIT_MESSAGE_GENERATION = `Please could you write a commit message for my changes.
Explain what were the changes and why the changes were done.
Focus the most important changes.
Use the present tense.
Use a semantic commit prefix.
Hard wrap lines at 72 characters.
Ensure the title is only 50 characters.
Do not start any lines with the hash symbol.
Only respond with the commit message.
Here is my git diff:
diff --git a/src/utils/typing.ts b/src/utils/typing.ts
index 1cbfaa2..7aeebcf 100644
--- a/src/utils/typing.ts
+++ b/src/utils/typing.ts
@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject
(Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0)
);
}
+
+export function isArrayOf<T>(
+ something: unknown,
+ check: (value: unknown) => value is T
+): something is T[] {
+ return Array.isArray(something) && something.every(check);
+}`;
const PROMPT_EXAMPLE_BRANCH_NAME_GENERATION = `Please could you write a branch name for my changes.
A branch name represent a brief description of the changes in the diff (branch).
Branch names should contain no whitespace and instead use dashes to separate words.
Branch names should contain a maximum of 5 words.
Only respond with the branch name.
Here is my git diff:
diff --git a/src/utils/typing.ts b/src/utils/typing.ts
index 1cbfaa2..7aeebcf 100644
--- a/src/utils/typing.ts
+++ b/src/utils/typing.ts
@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject
(Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0)
);
}
+
+export function isArrayOf<T>(
+ something: unknown,
+ check: (value: unknown) => value is T
+): something is T[] {
+ return Array.isArray(something) && something.every(check);
+}`;
enum OllamaAPEndpoint {
Generate = 'api/generate',
Chat = 'api/chat',
@@ -91,7 +42,7 @@ interface OllamaChatMessageFormat {
result: string;
}
const OllamaChatMessageFormatSchema = {
const OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA = {
type: 'object',
properties: {
result: { type: 'string' }
@@ -121,32 +72,16 @@ function isOllamaChatResponse(response: unknown): response is OllamaChatResponse
}
export class OllamaClient implements AIClient {
defaultCommitTemplate = LONG_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = LONG_DEFAULT_BRANCH_TEMPLATE;
constructor(
private endpoint: string,
private modelName: string
) {}
async branchName(prompt: string): Promise<string> {
const messages: PromptMessage[] = [
{
role: MessageRole.System,
content: `You are an expert in software development. Answer the given user prompts following the specified instructions.
Return your response in JSON and only use the following JSON schema:
${JSON.stringify(OllamaChatMessageFormatSchema, null, 2)}`
},
{
role: MessageRole.User,
content: PROMPT_EXAMPLE_BRANCH_NAME_GENERATION
},
{
role: MessageRole.Assistant,
content: JSON.stringify({
result: `utils-typing-is-array-of-type`
})
},
{ role: MessageRole.User, content: prompt }
];
async evaluate(prompt: PromptMessage[]) {
const messages = this.formatPrompt(prompt);
const response = await this.chat(messages);
const rawResponse = JSON.parse(response.message.content);
if (!isOllamaChatMessageFormat(rawResponse)) {
@@ -156,36 +91,31 @@ ${JSON.stringify(OllamaChatMessageFormatSchema, null, 2)}`
return rawResponse.result;
}
async evaluate(prompt: string) {
const messages: PromptMessage[] = [
/**
* Appends a system message which instructs the model to respond using a particular JSON schema
* Modifies the prompt's Assistant messages to make use of the correct schema
*/
private formatPrompt(prompt: PromptMessage[]) {
const withFormattedResponses = prompt.map((promptMessage) => {
if (promptMessage.role == MessageRole.Assistant) {
return {
role: MessageRole.Assistant,
content: JSON.stringify({ result: promptMessage.content })
};
} else {
return promptMessage;
}
});
return [
{
role: MessageRole.System,
content: `You are an expert in software development. Answer the given user prompts following the specified instructions.
Return your response in JSON and only use the following JSON schema:
${JSON.stringify(OllamaChatMessageFormatSchema, null, 2)}`
${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}`
},
{
role: MessageRole.User,
content: PROMT_EXAMPLE_COMMIT_MESSAGE_GENERATION
},
{
role: MessageRole.Assistant,
content: JSON.stringify({
result: `Typing utilities: Check for array of type
Added an utility function to check whether a given value is an array of a specific type.`
})
},
{ role: MessageRole.User, content: prompt }
...withFormattedResponses
];
const response = await this.chat(messages);
const rawResponse = JSON.parse(response.message.content);
if (!isOllamaChatMessageFormat(rawResponse)) {
throw new Error('Invalid response: ' + response.message.content);
}
return rawResponse.result;
}
/**
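
To make the reshaping concrete, here is a sketch (not from the commit) of what formatPrompt produces for a few-shot prompt: Assistant turns get wrapped in the JSON result schema, and the system message carrying that schema is prepended. The literal contents are illustrative.

const input: PromptMessage[] = [
	{ role: MessageRole.User, content: 'Write a branch name for my changes: <example diff>' },
	{ role: MessageRole.Assistant, content: 'utils-typing-is-array-of-type' },
	{ role: MessageRole.User, content: 'Write a branch name for my changes: <real diff>' }
];

// formatPrompt(input) yields, roughly:
// [
//   { role: MessageRole.System,    content: 'You are an expert in software development. ... <JSON schema>' },
//   { role: MessageRole.User,      content: 'Write a branch name for my changes: <example diff>' },
//   { role: MessageRole.Assistant, content: '{"result":"utils-typing-is-array-of-type"}' },
//   { role: MessageRole.User,      content: 'Write a branch name for my changes: <real diff>' }
// ]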

View File

@@ -1,22 +1,19 @@
import {
MessageRole,
type OpenAIModelName,
type PromptMessage,
type AIClient
} from '$lib/ai/types';
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import type { OpenAIModelName, PromptMessage, AIClient } from '$lib/ai/types';
import type OpenAI from 'openai';
export class OpenAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
constructor(
private modelName: OpenAIModelName,
private openAI: OpenAI
) {}
async evaluate(prompt: string) {
const messages: PromptMessage[] = [{ role: MessageRole.User, content: prompt }];
async evaluate(prompt: PromptMessage[]) {
const response = await this.openAI.chat.completions.create({
messages,
messages: prompt,
model: this.modelName,
max_tokens: 400
});

app/src/lib/ai/prompts.ts (new file, 107 lines)
View File

@@ -0,0 +1,107 @@
import { type PromptMessage, MessageRole } from '$lib/ai/types';
export const SHORT_DEFAULT_COMMIT_TEMPLATE: PromptMessage[] = [
{
role: MessageRole.User,
content: `Please could you write a commit message for my changes.
Only respond with the commit message. Don't give any notes.
Explain what were the changes and why the changes were done.
Focus the most important changes.
Use the present tense.
Use a semantic commit prefix.
Hard wrap lines at 72 characters.
Ensure the title is only 50 characters.
Do not start any lines with the hash symbol.
%{brief_style}
%{emoji_style}
Here is my git diff:
%{diff}`
}
];
export const LONG_DEFAULT_COMMIT_TEMPLATE: PromptMessage[] = [
{
role: MessageRole.User,
content: `Please could you write a commit message for my changes.
Explain what were the changes and why the changes were done.
Focus the most important changes.
Use the present tense.
Use a semantic commit prefix.
Hard wrap lines at 72 characters.
Ensure the title is only 50 characters.
Do not start any lines with the hash symbol.
Only respond with the commit message.
Here is my git diff:
diff --git a/src/utils/typing.ts b/src/utils/typing.ts
index 1cbfaa2..7aeebcf 100644
--- a/src/utils/typing.ts
+++ b/src/utils/typing.ts
@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject
(Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0)
);
}
+
+export function isArrayOf<T>(
+ something: unknown,
+ check: (value: unknown) => value is T
+): something is T[] {
+ return Array.isArray(something) && something.every(check);
+}`
},
{
role: MessageRole.Assistant,
content: `Typing utilities: Check for array of type
Added an utility function to check whether a given value is an array of a specific type.`
},
...SHORT_DEFAULT_COMMIT_TEMPLATE
];
export const SHORT_DEFAULT_BRANCH_TEMPLATE: PromptMessage[] = [
{
role: MessageRole.User,
content: `Please could you write a branch name for my changes.
A branch name represent a brief description of the changes in the diff (branch).
Branch names should contain no whitespace and instead use dashes to separate words.
Branch names should contain a maximum of 5 words.
Only respond with the branch name.
Here is my git diff:
%{diff}`
}
];
export const LONG_DEFAULT_BRANCH_TEMPLATE: PromptMessage[] = [
{
role: MessageRole.User,
content: `Please could you write a branch name for my changes.
A branch name represent a brief description of the changes in the diff (branch).
Branch names should contain no whitespace and instead use dashes to separate words.
Branch names should contain a maximum of 5 words.
Only respond with the branch name.
Here is my git diff:
diff --git a/src/utils/typing.ts b/src/utils/typing.ts
index 1cbfaa2..7aeebcf 100644
--- a/src/utils/typing.ts
+++ b/src/utils/typing.ts
@@ -35,3 +35,10 @@ export function isNonEmptyObject(something: unknown): something is UnknownObject
(Object.keys(something).length > 0 || Object.getOwnPropertySymbols(something).length > 0)
);
}
+
+export function isArrayOf<T>(
+ something: unknown,
+ check: (value: unknown) => value is T
+): something is T[] {
+ return Array.isArray(something) && something.every(check);
+}`
},
{
role: MessageRole.Assistant,
content: `utils-typing-is-array-of-type`
},
...SHORT_DEFAULT_BRANCH_TEMPLATE
];
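
Since the long templates spread the short ones onto their tail, each resolves to a three-message few-shot prompt. Sketched out for the commit case (contents abbreviated):

const resolved: PromptMessage[] = [...LONG_DEFAULT_COMMIT_TEMPLATE];
// resolved[0]  User:      the worked example request ending in the sample git diff
// resolved[1]  Assistant: the sample commit message answer
// resolved[2]  User:      the short template, still carrying the %{diff},
//                         %{brief_style} and %{emoji_style} placeholders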

View File

@@ -1,8 +1,15 @@
import { AnthropicAIClient } from '$lib/ai/anthropicClient';
import { ButlerAIClient } from '$lib/ai/butlerClient';
import { OpenAIClient } from '$lib/ai/openAIClient';
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { AIService, GitAIConfigKey, KeyOption, buildDiff } from '$lib/ai/service';
import { AnthropicModelName, ModelKind, OpenAIModelName, type AIClient } from '$lib/ai/types';
import {
AnthropicModelName,
ModelKind,
OpenAIModelName,
type AIClient,
type PromptMessage
} from '$lib/ai/types';
import { HttpClient } from '$lib/backend/httpClient';
import * as toasts from '$lib/utils/toasts';
import { Hunk } from '$lib/vbranches/types';
@@ -40,9 +47,11 @@ const fetchMock = vi.fn();
const cloud = new HttpClient(fetchMock);
class DummyAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
constructor(private response = 'lorem ipsum') {}
async evaluate(_prompt: string) {
async evaluate(_prompt: PromptMessage[]) {
return this.response;
}
}

View File

@@ -6,7 +6,14 @@ import {
OllamaClient
} from '$lib/ai/ollamaClient';
import { OpenAIClient } from '$lib/ai/openAIClient';
import { OpenAIModelName, type AIClient, AnthropicModelName, ModelKind } from '$lib/ai/types';
import {
OpenAIModelName,
type AIClient,
AnthropicModelName,
ModelKind,
type PromptMessage,
MessageRole
} from '$lib/ai/types';
import { splitMessage } from '$lib/utils/commitMessage';
import * as toasts from '$lib/utils/toasts';
import OpenAI from 'openai';
@@ -16,34 +23,6 @@ import type { Hunk } from '$lib/vbranches/types';
const maxDiffLengthLimitForAPI = 5000;
const defaultCommitTemplate = `
Please could you write a commit message for my changes.
Only respond with the commit message. Don't give any notes.
Explain what were the changes and why the changes were done.
Focus the most important changes.
Use the present tense.
Use a semantic commit prefix.
Hard wrap lines at 72 characters.
Ensure the title is only 50 characters.
Do not start any lines with the hash symbol.
%{brief_style}
%{emoji_style}
Here is my git diff:
%{diff}
`;
const defaultBranchTemplate = `
Please could you write a branch name for my changes.
A branch name represent a brief description of the changes in the diff (branch).
Branch names should contain no whitespace and instead use dashes to separate words.
Branch names should contain a maximum of 5 words.
Only respond with the branch name.
Here is my git diff:
%{diff}
`;
export enum KeyOption {
BringYourOwn = 'bringYourOwn',
ButlerAPI = 'butlerAPI'
@@ -66,13 +45,13 @@ type SummarizeCommitOpts = {
hunks: Hunk[];
useEmojiStyle?: boolean;
useBriefStyle?: boolean;
commitTemplate?: string;
commitTemplate?: PromptMessage[];
userToken?: string;
};
type SummarizeBranchOpts = {
hunks: Hunk[];
branchTemplate?: string;
branchTemplate?: PromptMessage[];
userToken?: string;
};
@@ -261,24 +240,37 @@ export class AIService {
hunks,
useEmojiStyle = false,
useBriefStyle = false,
commitTemplate = defaultCommitTemplate,
commitTemplate,
userToken
}: SummarizeCommitOpts) {
const aiClient = await this.buildClient(userToken);
if (!aiClient) return;
const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI();
let prompt = commitTemplate.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit));
const defaultedCommitTemplate = commitTemplate || aiClient.defaultCommitTemplate;
const briefPart = useBriefStyle
? 'The commit message must be only one sentence and as short as possible.'
: '';
prompt = prompt.replaceAll('%{brief_style}', briefPart);
const prompt = defaultedCommitTemplate.map((promptMessage) => {
if (promptMessage.role != MessageRole.User) {
return promptMessage;
}
const emojiPart = useEmojiStyle
? 'Make use of GitMoji in the title prefix.'
: "Don't use any emoji.";
prompt = prompt.replaceAll('%{emoji_style}', emojiPart);
let content = promptMessage.content.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit));
const briefPart = useBriefStyle
? 'The commit message must be only one sentence and as short as possible.'
: '';
content = content.replaceAll('%{brief_style}', briefPart);
const emojiPart = useEmojiStyle
? 'Make use of GitMoji in the title prefix.'
: "Don't use any emoji.";
content = content.replaceAll('%{emoji_style}', emojiPart);
return {
role: MessageRole.User,
content
};
});
let message = await aiClient.evaluate(prompt);
@@ -290,18 +282,24 @@ export class AIService {
return description ? `${title}\n\n${description}` : title;
}
async summarizeBranch({
hunks,
branchTemplate = defaultBranchTemplate,
userToken = undefined
}: SummarizeBranchOpts) {
async summarizeBranch({ hunks, branchTemplate, userToken = undefined }: SummarizeBranchOpts) {
const aiClient = await this.buildClient(userToken);
if (!aiClient) return;
const diffLengthLimit = await this.getDiffLengthLimitConsideringAPI();
const prompt = branchTemplate.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit));
const branchNamePromise = aiClient.evaluate(prompt);
const message = await branchNamePromise;
const defaultedBranchTemplate = branchTemplate || aiClient.defaultBranchTemplate;
const prompt = defaultedBranchTemplate.map((promptMessage) => {
if (promptMessage.role != MessageRole.User) {
return promptMessage;
}
return {
role: MessageRole.User,
content: promptMessage.content.replaceAll('%{diff}', buildDiff(hunks, diffLengthLimit))
};
});
const message = await aiClient.evaluate(prompt);
return message.replaceAll(' ', '-').replaceAll('\n', '-');
}
}
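
A hedged usage sketch of the new plumbing, assuming aiService, hunks and userToken are already in scope; the custom template below is hypothetical. Only User messages have their placeholders substituted, Assistant examples pass through untouched, and omitting commitTemplate falls back to aiClient.defaultCommitTemplate.

const customTemplate: PromptMessage[] = [
	{ role: MessageRole.User, content: 'Summarize this diff in one sentence:\n%{diff}' },
	{ role: MessageRole.Assistant, content: 'Add isArrayOf type guard to typing utilities' },
	{ role: MessageRole.User, content: 'Now this one:\n%{brief_style}\n%{emoji_style}\n%{diff}' }
];

const commitMessage = await aiService.summarizeCommit({
	hunks,
	useBriefStyle: true,
	commitTemplate: customTemplate,
	userToken
});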

View File

@@ -28,5 +28,8 @@ export interface PromptMessage {
}
export interface AIClient {
evaluate(prompt: string): Promise<string>;
evaluate(prompt: PromptMessage[]): Promise<string>;
defaultBranchTemplate: PromptMessage[];
defaultCommitTemplate: PromptMessage[];
}
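
For reference, a minimal client satisfying the widened interface, modelled on the test double in the test file above; the echo behaviour is made up for illustration.

import type { AIClient, PromptMessage } from '$lib/ai/types';
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';

export class EchoAIClient implements AIClient {
	defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
	defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;

	// Echoes the last message's content instead of calling a model.
	async evaluate(prompt: PromptMessage[]): Promise<string> {
		return prompt.at(-1)?.content ?? '';
	}
}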