Merge pull request #5087 from gitbutlerapp/fix-Set-the-system-prompt-separately

fix: Set the system prompt for Anthropic
This commit is contained in:
Esteban Vega 2024-10-10 11:59:42 +02:00 committed by GitHub
commit 979b5620a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 21 deletions

View File

@ -1,3 +1,4 @@
import { splitPromptMessages } from './anthropicUtils';
import {
SHORT_DEFAULT_COMMIT_TEMPLATE,
SHORT_DEFAULT_BRANCH_TEMPLATE,
@ -11,29 +12,11 @@ import {
} from '$lib/ai/types';
import { andThenAsync, ok, wrapAsync, type Result } from '$lib/result';
import Anthropic from '@anthropic-ai/sdk';
import type { MessageParam, RawMessageStreamEvent } from '@anthropic-ai/sdk/resources/messages.mjs';
import type { RawMessageStreamEvent } from '@anthropic-ai/sdk/resources/messages.mjs';
import type { Stream } from '@anthropic-ai/sdk/streaming.mjs';
const DEFAULT_MAX_TOKENS = 1024;
function splitPromptMessages(prompt: Prompt): [MessageParam[], string | undefined] {
const messages: MessageParam[] = [];
let system: string | undefined = undefined;
for (const message of prompt) {
if (message.role === 'system') {
system = message.content;
continue;
}
messages.push({
role: message.role,
content: message.content
});
}
return [messages, system];
}
export class AnthropicAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;

View File

@ -0,0 +1,35 @@
import { isMessageRole, type Prompt } from './types';
import { isStr } from '$lib/utils/string';
import type { MessageParam } from '@anthropic-ai/sdk/resources/messages.mjs';
export function splitPromptMessages(prompt: Prompt): [MessageParam[], string | undefined] {
const messages: MessageParam[] = [];
let system: string | undefined = undefined;
for (const message of prompt) {
if (message.role === 'system') {
system = message.content;
continue;
}
messages.push({
role: message.role,
content: message.content
});
}
return [messages, system];
}
export function messageParamToPrompt(messages: MessageParam[]): Prompt {
const result: Prompt = [];
for (const message of messages) {
if (!isStr(message.content)) continue;
if (!isMessageRole(message.role)) continue;
result.push({
role: message.role,
content: message.content
});
}
return result;
}

View File

@ -1,12 +1,28 @@
import { messageParamToPrompt, splitPromptMessages } from './anthropicUtils';
import {
SHORT_DEFAULT_BRANCH_TEMPLATE,
SHORT_DEFAULT_COMMIT_TEMPLATE,
SHORT_DEFAULT_PR_TEMPLATE
} from '$lib/ai/prompts';
import { ModelKind, type AIClient, type Prompt } from '$lib/ai/types';
import { map, type Result } from '$lib/result';
import type { AIClient, ModelKind, Prompt } from '$lib/ai/types';
import type { HttpClient } from '$lib/backend/httpClient';
function splitPromptMessagesIfNecessary(
modelKind: ModelKind,
prompt: Prompt
): [Prompt, string | undefined] {
switch (modelKind) {
case ModelKind.Anthropic: {
const [messages, system] = splitPromptMessages(prompt);
return [messageParamToPrompt(messages), system];
}
case ModelKind.OpenAI:
case ModelKind.Ollama:
return [prompt, undefined];
}
}
export class ButlerAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
@ -19,11 +35,13 @@ export class ButlerAIClient implements AIClient {
) {}
async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
const [messages, system] = splitPromptMessagesIfNecessary(this.modelKind, prompt);
const response = await this.cloud.postSafe<{ message: string }>(
'evaluate_prompt/predict.json',
{
body: {
messages: prompt,
messages,
system,
max_tokens: 400,
model_kind: this.modelKind
},

View File

@ -1,3 +1,4 @@
import { isStr } from '$lib/utils/string';
import type { Persisted } from '$lib/persisted/persisted';
import type { Result } from '$lib/result';
@ -30,6 +31,12 @@ export enum MessageRole {
Assistant = 'assistant'
}
export function isMessageRole(role: unknown): role is MessageRole {
if (!isStr(role)) return false;
const roles = Object.values(MessageRole);
return roles.includes(role as MessageRole);
}
export interface PromptMessage {
content: string;
role: MessageRole;