2024-04-26 12:43:35 +03:00
|
|
|
import { randomBytes } from 'node:crypto';
|
|
|
|
|
|
|
|
import { INestApplication } from '@nestjs/common';
|
|
|
|
import request from 'supertest';
|
|
|
|
|
|
|
|
import {
|
|
|
|
DEFAULT_DIMENSIONS,
|
|
|
|
OpenAIProvider,
|
|
|
|
} from '../../src/plugins/copilot/providers/openai';
|
|
|
|
import {
|
|
|
|
CopilotCapability,
|
|
|
|
CopilotImageToImageProvider,
|
|
|
|
CopilotImageToTextProvider,
|
|
|
|
CopilotProviderType,
|
|
|
|
CopilotTextToEmbeddingProvider,
|
|
|
|
CopilotTextToImageProvider,
|
|
|
|
CopilotTextToTextProvider,
|
|
|
|
PromptMessage,
|
|
|
|
} from '../../src/plugins/copilot/types';
|
|
|
|
import { gql } from './common';
|
|
|
|
import { handleGraphQLError } from './utils';
|
|
|
|
|
2024-06-07 08:53:39 +03:00
|
|
|
/**
 * Mock copilot provider for server tests.
 *
 * It reports itself as `CopilotProviderType.Test`, accepts a fixed list of
 * model names, advertises every capability listed in `capabilities`, and its
 * overridden generation methods return canned output instead of real model
 * responses.
 */
// NOTE(review): the directive below suppresses a type error on the class
// declaration — presumably the static `type`/`capabilities` overrides not
// matching OpenAIProvider's static side; confirm before removing it.
// @ts-expect-error no error
export class MockCopilotTestProvider
  extends OpenAIProvider
  implements
    CopilotTextToTextProvider,
    CopilotTextToEmbeddingProvider,
    CopilotTextToImageProvider,
    CopilotImageToImageProvider,
    CopilotImageToTextProvider
{
  static override readonly type = CopilotProviderType.Test;

  /** Model names this mock accepts (see `isModelAvailable`). */
  override readonly availableModels = [
    'test',
    'gpt-4o',
    'fast-sdxl/image-to-image',
    'lcm-sd15-i2i',
    'clarity-upscaler',
    'imageutils/rembg',
  ];

  static override readonly capabilities = [
    CopilotCapability.TextToText,
    CopilotCapability.TextToEmbedding,
    CopilotCapability.TextToImage,
    CopilotCapability.ImageToImage,
    CopilotCapability.ImageToText,
  ];

  constructor() {
    // Placeholder key: the generation methods below are overridden and do not
    // use it; presumably the base class only needs a non-empty value.
    super({ apiKey: '1' });
  }

  override getCapabilities(): CopilotCapability[] {
    return MockCopilotTestProvider.capabilities;
  }

  /** Always reports the configuration as usable, whatever is passed. */
  static override assetsConfig(_config: any) {
    return true;
  }

  override get type(): CopilotProviderType {
    return CopilotProviderType.Test;
  }

  /** True when `model` is one of `availableModels`. */
  override async isModelAvailable(model: string): Promise<boolean> {
    return this.availableModels.includes(model);
  }

  // ====== text to text ======

  /**
   * Validates the inputs via `checkParams`, then returns a fixed reply.
   * `_options` is accepted for interface compatibility and ignored.
   */
  override async generateText(
    messages: PromptMessage[],
    model: string = 'test',
    _options: {
      temperature?: number;
      maxTokens?: number;
      signal?: AbortSignal;
      user?: string;
    } = {}
  ): Promise<string> {
    this.checkParams({ messages, model });
    return 'generate text to text';
  }

  /**
   * Validates the inputs via `checkParams`, then streams a fixed reply one
   * character (code point) at a time.
   *
   * The abort signal is checked after each emitted chunk, so at least one
   * chunk is always yielded even if the signal is already aborted.
   */
  override async *generateTextStream(
    messages: PromptMessage[],
    model: string = 'gpt-3.5-turbo',
    options: {
      temperature?: number;
      maxTokens?: number;
      signal?: AbortSignal;
      user?: string;
    } = {}
  ): AsyncIterable<string> {
    this.checkParams({ messages, model });

    const result = 'generate text to text stream';
    // A string is a synchronous iterable, so a plain `for...of` suffices.
    for (const message of result) {
      yield message;
      if (options.signal?.aborted) {
        break;
      }
    }
  }

  // ====== text to embedding ======

  /**
   * Validates the inputs via `checkParams`, then returns one random vector of
   * `options.dimensions` integers in the range 0..127.
   *
   * NOTE(review): a single vector is returned regardless of how many
   * messages are passed in.
   */
  override async generateEmbedding(
    messages: string | string[],
    model: string,
    options: {
      dimensions: number;
      signal?: AbortSignal;
      user?: string;
    } = { dimensions: DEFAULT_DIMENSIONS }
  ): Promise<number[][]> {
    messages = Array.isArray(messages) ? messages : [messages];
    this.checkParams({ embeddings: messages, model });

    return [Array.from(randomBytes(options.dimensions)).map(v => v % 128)];
  }

  // ====== text to image ======

  /**
   * Returns a fake image URL derived from `model`, followed by the prompt
   * (the content of the last message).
   *
   * @throws Error('Prompt is required') when there are no messages or the
   *   last message has empty content.
   */
  override async generateImages(
    messages: PromptMessage[],
    model: string = 'test',
    _options: {
      signal?: AbortSignal;
      user?: string;
    } = {}
  ): Promise<Array<string>> {
    // Read the last message without popping it, so the caller's array is not
    // mutated as a side effect of generating images.
    const prompt = messages[messages.length - 1]?.content;
    if (!prompt) {
      throw new Error('Prompt is required');
    }

    // Returning the prompt lets test cases easily verify the final prompt.
    return [`https://example.com/${model}.jpg`, prompt];
  }

  /** Streams each entry produced by `generateImages` in order. */
  override async *generateImagesStream(
    messages: PromptMessage[],
    model: string = 'dall-e-3',
    options: {
      signal?: AbortSignal;
      user?: string;
    } = {}
  ): AsyncIterable<string> {
    const ret = await this.generateImages(messages, model, options);
    for (const url of ret) {
      yield url;
    }
  }
}
|
|
|
|
|
|
|
|
/**
 * Runs the `createCopilotSession` GraphQL mutation as the given user and
 * returns the mutation's result (presumably the new session id — it is what
 * the other helpers take as `sessionId`).
 *
 * The request must answer HTTP 200; GraphQL-level errors are surfaced through
 * `handleGraphQLError` before the result is read.
 */
export async function createCopilotSession(
  app: INestApplication,
  userToken: string,
  workspaceId: string,
  docId: string,
  promptName: string
): Promise<string> {
  const mutation = `
    mutation createCopilotSession($options: CreateChatSessionInput!) {
      createCopilotSession(options: $options)
    }
  `;
  const options = { workspaceId, docId, promptName };

  const response = await request(app.getHttpServer())
    .post(gql)
    .auth(userToken, { type: 'bearer' })
    .set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
    .send({ query: mutation, variables: { options } })
    .expect(200);

  handleGraphQLError(response);

  const { data } = response.body;
  return data.createCopilotSession;
}
|
|
|
|
|
|
|
|
/**
 * Runs the `createCopilotMessage` GraphQL mutation for `sessionId` and
 * returns the mutation's result (presumably the new message id, which
 * `chatWithText` accepts as `messageId`).
 *
 * Optional fields that are `undefined` are omitted from the JSON payload.
 * NOTE(review): `blobs` are sent as JSON, and `JSON.stringify` turns an
 * `ArrayBuffer` into `{}` — confirm the server does not expect binary data.
 */
export async function createCopilotMessage(
  app: INestApplication,
  userToken: string,
  sessionId: string,
  content?: string,
  attachments?: string[],
  blobs?: ArrayBuffer[],
  params?: Record<string, string>
): Promise<string> {
  const mutation = `
    mutation createCopilotMessage($options: CreateChatMessageInput!) {
      createCopilotMessage(options: $options)
    }
  `;
  const options = { sessionId, content, attachments, blobs, params };

  const response = await request(app.getHttpServer())
    .post(gql)
    .auth(userToken, { type: 'bearer' })
    .set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
    .send({ query: mutation, variables: { options } })
    .expect(200);

  handleGraphQLError(response);

  const { data } = response.body;
  return data.createCopilotMessage;
}
|
|
|
|
|
|
|
|
/**
 * Sends an authenticated GET to `/api/copilot/chat/<sessionId><prefix>` and
 * returns the raw response body as text.
 *
 * @param app - Nest application under test.
 * @param userToken - bearer token of the requesting user.
 * @param sessionId - chat session id, used as a path segment.
 * @param messageId - optional message id; when non-empty it is sent as the
 *   `messageId` query parameter.
 * @param prefix - path suffix appended after the session id, e.g. `/stream`.
 * @returns the response body text; the request asserts HTTP 200.
 */
export async function chatWithText(
  app: INestApplication,
  userToken: string,
  sessionId: string,
  messageId?: string,
  prefix = ''
): Promise<string> {
  // Encode the id so characters such as `&`, `#`, `+` or spaces cannot
  // corrupt the query string.
  const query = messageId
    ? `?messageId=${encodeURIComponent(messageId)}`
    : '';

  const res = await request(app.getHttpServer())
    .get(`/api/copilot/chat/${sessionId}${prefix}${query}`)
    .auth(userToken, { type: 'bearer' })
    .expect(200);

  return res.text;
}
|
|
|
|
|
|
|
|
/**
 * Calls the `/stream` variant of the chat endpoint for `sessionId` and
 * returns the raw response body; delegates to `chatWithText`.
 */
export async function chatWithTextStream(
  app: INestApplication,
  userToken: string,
  sessionId: string,
  messageId?: string
) {
  const streamSuffix = '/stream';
  const body = await chatWithText(
    app,
    userToken,
    sessionId,
    messageId,
    streamSuffix
  );
  return body;
}
|
|
|
|
|
2024-06-07 08:53:44 +03:00
|
|
|
/**
 * Calls the `/workflow` variant of the chat endpoint for `sessionId` and
 * returns the raw response body; delegates to `chatWithText`.
 */
export async function chatWithWorkflow(
  app: INestApplication,
  userToken: string,
  sessionId: string,
  messageId?: string
) {
  const workflowSuffix = '/workflow';
  const body = await chatWithText(
    app,
    userToken,
    sessionId,
    messageId,
    workflowSuffix
  );
  return body;
}
|
|
|
|
|
2024-04-26 12:43:35 +03:00
|
|
|
/**
 * Calls the `/images` variant of the chat endpoint for `sessionId` and
 * returns the raw response body; delegates to `chatWithText`.
 */
export async function chatWithImages(
  app: INestApplication,
  userToken: string,
  sessionId: string,
  messageId?: string
) {
  const imagesSuffix = '/images';
  const body = await chatWithText(
    app,
    userToken,
    sessionId,
    messageId,
    imagesSuffix
  );
  return body;
}
|
|
|
|
|
|
|
|
/**
 * Builds the text of an event stream in which each chunk of `content` is
 * one event carrying the given `event` name and `id`.
 *
 * A string is split into its characters (code points); an array is used
 * element by element. Events are separated by blank lines and the whole
 * payload ends with `\n\n` (an empty input yields just `\n\n`).
 */
export function textToEventStream(
  content: string | string[],
  id: string,
  event = 'message'
): string {
  const frames: string[] = [];
  for (const chunk of content) {
    frames.push(`\nevent: ${event}\nid: ${id}\ndata: ${chunk}`);
  }
  return `${frames.join('\n')}\n\n`;
}
|
|
|
|
|
|
|
|
/** One message of a copilot chat history, as selected by `getHistories`. */
type ChatMessage = {
  role: string;
  content: string;
  /** `null` when the server reports no attachments. */
  attachments: Array<string> | null;
  createdAt: string;
};

/** One chat session entry returned by `getHistories`. */
type History = {
  sessionId: string;
  tokens: number;
  /** `null` when the server reports no action for the session. */
  action: string | null;
  createdAt: string;
  messages: Array<ChatMessage>;
};
|
|
|
|
|
|
|
|
/**
 * Fetches the current user's copilot chat histories for a workspace
 * (optionally narrowed to a doc and by `options`) through the
 * `getCopilotHistories` GraphQL query.
 *
 * @returns the `histories` list, or `[]` when the response has no current
 *   user, no copilot field, or a falsy `histories` value.
 */
export async function getHistories(
  app: INestApplication,
  userToken: string,
  variables: {
    workspaceId: string;
    docId?: string;
    options?: {
      sessionId?: string;
      action?: boolean;
      limit?: number;
      skip?: number;
    };
  }
): Promise<History[]> {
  const historiesQuery = `
    query getCopilotHistories(
      $workspaceId: String!
      $docId: String
      $options: QueryChatHistoriesInput
    ) {
      currentUser {
        copilot(workspaceId: $workspaceId) {
          histories(docId: $docId, options: $options) {
            sessionId
            tokens
            action
            createdAt
            messages {
              role
              content
              attachments
              createdAt
            }
          }
        }
      }
    }
  `;

  const response = await request(app.getHttpServer())
    .post(gql)
    .auth(userToken, { type: 'bearer' })
    .set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
    .send({ query: historiesQuery, variables })
    .expect(200);

  handleGraphQLError(response);

  const copilot = response.body.data.currentUser?.copilot;
  return copilot?.histories || [];
}
|