import { randomBytes } from 'node:crypto';

import { INestApplication } from '@nestjs/common';
import request from 'supertest';

import {
  DEFAULT_DIMENSIONS,
  OpenAIProvider,
} from '../../src/plugins/copilot/providers/openai';
import {
  CopilotCapability,
  CopilotChatOptions,
  CopilotEmbeddingOptions,
  CopilotImageToImageProvider,
  CopilotImageToTextProvider,
  CopilotProviderType,
  CopilotTextToEmbeddingProvider,
  CopilotTextToImageProvider,
  CopilotTextToTextProvider,
  PromptMessage,
} from '../../src/plugins/copilot/types';
import { NodeExecutorType } from '../../src/plugins/copilot/workflow/executor';
import {
  WorkflowGraph,
  WorkflowNodeType,
  WorkflowParams,
} from '../../src/plugins/copilot/workflow/types';
import { gql } from './common';
import { handleGraphQLError } from './utils';

/**
 * Copilot provider used by tests: every generation method returns a fixed or
 * trivially derived value instead of producing real model output.
 *
 * It extends `OpenAIProvider` (constructed with a dummy API key) and reports
 * itself as `CopilotProviderType.Test` with all five capabilities listed in
 * {@link MockCopilotTestProvider.capabilities}.
 */
// @ts-expect-error no error
export class MockCopilotTestProvider
  extends OpenAIProvider
  implements
    CopilotTextToTextProvider,
    CopilotTextToEmbeddingProvider,
    CopilotTextToImageProvider,
    CopilotImageToImageProvider,
    CopilotImageToTextProvider
{
  static override readonly type = CopilotProviderType.Test;

  /** Model names that {@link isModelAvailable} reports as available. */
  override readonly availableModels = [
    'test',
    'gpt-4o',
    'fast-sdxl/image-to-image',
    'lcm-sd15-i2i',
    'clarity-upscaler',
    'imageutils/rembg',
  ];

  static override readonly capabilities = [
    CopilotCapability.TextToText,
    CopilotCapability.TextToEmbedding,
    CopilotCapability.TextToImage,
    CopilotCapability.ImageToImage,
    CopilotCapability.ImageToText,
  ];

  constructor() {
    // The key is a placeholder; the overridden methods below return canned data.
    super({ apiKey: '1' });
  }

  /** @returns the static capability list of this mock. */
  override getCapabilities(): CopilotCapability[] {
    return MockCopilotTestProvider.capabilities;
  }

  /** Accepts any configuration: the mock needs no real credentials. */
  static override assetsConfig(_config: any) {
    return true;
  }

  override get type(): CopilotProviderType {
    return CopilotProviderType.Test;
  }

  /** @returns whether `model` is one of {@link availableModels}. */
  override async isModelAvailable(model: string): Promise<boolean> {
    return this.availableModels.includes(model);
  }

  // ====== text to text ======

  /**
   * Runs the inherited `checkParams` on the arguments, then returns a fixed
   * string.
   *
   * @returns always `'generate text to text'`.
   */
  override async generateText(
    messages: PromptMessage[],
    model: string = 'test',
    options: CopilotChatOptions = {}
  ): Promise<string> {
    this.checkParams({ messages, model, options });
    return 'generate text to text';
  }

  /**
   * Runs the inherited `checkParams`, then yields the fixed string
   * `'generate text to text stream'` one character at a time.
   *
   * Iteration stops after the current chunk once `options.signal` is aborted.
   */
  override async *generateTextStream(
    messages: PromptMessage[],
    model: string = 'gpt-3.5-turbo',
    options: CopilotChatOptions = {}
  ): AsyncIterable<string> {
    this.checkParams({ messages, model, options });

    const result = 'generate text to text stream';
    // Iterating a string yields its characters, so each chunk is one character.
    for await (const message of result) {
      yield message;
      if (options.signal?.aborted) {
        break;
      }
    }
  }

  // ====== text to embedding ======

  /**
   * Runs the inherited `checkParams`, then returns a single pseudo-random
   * vector of `options.dimensions` integers, each in the range 0..127.
   *
   * Only one vector is returned regardless of how many messages are passed.
   */
  override async generateEmbedding(
    messages: string | string[],
    model: string,
    options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
  ): Promise<number[][]> {
    messages = Array.isArray(messages) ? messages : [messages];
    this.checkParams({ embeddings: messages, model, options });

    return [Array.from(randomBytes(options.dimensions)).map(v => v % 128)];
  }

  // ====== text to image ======

  /**
   * Returns a fake image URL derived from `model`, followed by the content of
   * the first message.
   *
   * @throws Error when there is no first message or its content is empty.
   */
  override async generateImages(
    messages: PromptMessage[],
    model: string = 'test',
    _options: {
      signal?: AbortSignal;
      user?: string;
    } = {}
  ): Promise<Array<string>> {
    const { content: prompt } = messages[0] || {};
    if (!prompt) {
      throw new Error('Prompt is required');
    }

    // just let test case can easily verify the final prompt
    return [`https://example.com/${model}.jpg`, prompt];
  }

  /** Yields each entry of {@link generateImages} in order. */
  override async *generateImagesStream(
    messages: PromptMessage[],
    model: string = 'dall-e-3',
    options: {
      signal?: AbortSignal;
      user?: string;
    } = {}
  ): AsyncIterable<string> {
    const ret = await this.generateImages(messages, model, options);
    for (const url of ret) {
      yield url;
    }
  }
}

/**
 * Sends the `createCopilotSession` GraphQL mutation as the given user and
 * returns the `createCopilotSession` field of the response data.
 *
 * Expects HTTP 200 and passes the response through `handleGraphQLError`.
 */
export async function createCopilotSession(
  app: INestApplication,
  userToken: string,
  workspaceId: string,
  docId: string,
  promptName: string
  // NOTE(review): the return type argument was missing in the source; `string`
  // is assumed because the value is used as a `sessionId` elsewhere — confirm.
): Promise<string> {
  const res = await request(app.getHttpServer())
    .post(gql)
    .auth(userToken, { type: 'bearer' })
    .set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
    .send({
      query: `
        mutation createCopilotSession($options: CreateChatSessionInput!) {
          createCopilotSession(options: $options)
        }
      `,
      variables: { options: { workspaceId, docId, promptName } },
    })
    .expect(200);

  handleGraphQLError(res);

  return res.body.data.createCopilotSession;
}

/**
 * Sends the `createCopilotMessage` GraphQL mutation as the given user and
 * returns the `createCopilotMessage` field of the response data.
 *
 * Expects HTTP 200 and passes the response through `handleGraphQLError`.
 */
export async function createCopilotMessage(
  app: INestApplication,
  userToken: string,
  sessionId: string,
  content?: string,
  attachments?: string[],
  blobs?: ArrayBuffer[],
  // NOTE(review): the type arguments of `Record` were missing in the source;
  // `<string, string>` is assumed — confirm against CreateChatMessageInput.
  params?: Record<string, string>
  // NOTE(review): return type argument assumed to be `string` (a message id
  // passed to the chat helpers) — confirm.
): Promise<string> {
  const res = await request(app.getHttpServer())
    .post(gql)
    .auth(userToken, { type: 'bearer' })
    .set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
    .send({
      query: `
        mutation createCopilotMessage($options: CreateChatMessageInput!) {
          createCopilotMessage(options: $options)
        }
      `,
      variables: {
        options: { sessionId, content, attachments, blobs, params },
      },
    })
    .expect(200);

  handleGraphQLError(res);

  return res.body.data.createCopilotMessage;
}

/**
 * GETs `/api/copilot/chat/{sessionId}{prefix}` as the given user, appending
 * `?messageId=…` when a message id is supplied, and returns the response text.
 *
 * @param prefix path suffix placed after the session id (e.g. `'/stream'`).
 */
export async function chatWithText(
  app: INestApplication,
  userToken: string,
  sessionId: string,
  messageId?: string,
  prefix = ''
): Promise<string> {
  const query = messageId ? `?messageId=${messageId}` : '';
  const res = await request(app.getHttpServer())
    .get(`/api/copilot/chat/${sessionId}${prefix}${query}`)
    .auth(userToken, { type: 'bearer' })
    .expect(200);

  return res.text;
}

/** Same as {@link chatWithText}, against the `/stream` endpoint. */
export async function chatWithTextStream(
  app: INestApplication,
  userToken: string,
  sessionId: string,
  messageId?: string
) {
  return chatWithText(app, userToken, sessionId, messageId, '/stream');
}

/** Same as {@link chatWithText}, against the `/workflow` endpoint. */
export async function chatWithWorkflow(
  app: INestApplication,
  userToken: string,
  sessionId: string,
  messageId?: string
) {
  return chatWithText(app, userToken, sessionId, messageId, '/workflow');
}

/** Same as {@link chatWithText}, against the `/images` endpoint. */
export async function chatWithImages(
  app: INestApplication,
  userToken: string,
  sessionId: string,
  messageId?: string
) {
  return chatWithText(app, userToken, sessionId, messageId, '/images');
}

/**
 * Parses event-stream text into one object per block.
 *
 * Blocks are separated by two or more newlines. Inside a block every line is
 * split on the first `': '` into a key and a value.
 *
 * Two quirks are kept as-is because callers may depend on them:
 * - The leading/trailing newline strip only fires when the payload between
 *   them contains no newline (`.` does not match line terminators here).
 * - Empty lines and empty blocks therefore produce an entry with an empty
 *   key; {@link array2sse} filters those keys out again.
 */
export function sse2array(eventSource: string) {
  const blocks = eventSource.replace(/^\n(.*?)\n$/, '$1').split(/\n\n+/);
  return blocks.map(block =>
    block.split('\n').reduce(
      (prev, curr) => {
        const [key, ...values] = curr.split(': ');
        // Re-join so that values containing ': ' are kept intact.
        return Object.assign(prev, { [key]: values.join(': ') });
      },
      {} as Record<string, string>
    )
  );
}

/**
 * Serialises blocks back to event-stream text: each block becomes a newline
 * followed by its `key: value` lines, and blocks are joined with a newline.
 * Entries with an empty key are skipped.
 */
export function array2sse(blocks: Record<string, string>[]) {
  return blocks
    .map(
      e =>
        '\n' +
        Object.entries(e)
          .filter(([k]) => !!k)
          .map(([k, v]) => `${k}: ${v}`)
          .join('\n')
    )
    .join('\n');
}

/**
 * Builds event-stream text with one event per element of `content`.
 *
 * A string is split with `Array.from`, so it produces one event per
 * character; an array produces one event per entry. The result ends with a
 * blank line.
 *
 * @param id value written to the `id:` field of every event.
 * @param event value written to the `event:` field of every event.
 */
export function textToEventStream(
  content: string | string[],
  id: string,
  event = 'message'
): string {
  return (
    Array.from(content)
      .map(x => `\nevent: ${event}\nid: ${id}\ndata: ${x}`)
      .join('\n') + '\n\n'
  );
}

/** One message as selected by the `getCopilotHistories` query. */
type ChatMessage = {
  role: string;
  content: string;
  attachments: string[] | null;
  createdAt: string;
};

/** One chat session history as selected by the `getCopilotHistories` query. */
type History = {
  sessionId: string;
  tokens: number;
  action: string | null;
  createdAt: string;
  messages: ChatMessage[];
};

/**
 * Runs the `getCopilotHistories` GraphQL query as the given user.
 *
 * Expects HTTP 200 and passes the response through `handleGraphQLError`.
 *
 * @returns the `currentUser.copilot.histories` list, or an empty array when
 * any part of that path is missing or the list is falsy.
 */
export async function getHistories(
  app: INestApplication,
  userToken: string,
  variables: {
    workspaceId: string;
    docId?: string;
    options?: {
      sessionId?: string;
      action?: boolean;
      limit?: number;
      skip?: number;
    };
  }
): Promise<History[]> {
  const res = await request(app.getHttpServer())
    .post(gql)
    .auth(userToken, { type: 'bearer' })
    .set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
    .send({
      query: `
        query getCopilotHistories(
          $workspaceId: String!
          $docId: String
          $options: QueryChatHistoriesInput
        ) {
          currentUser {
            copilot(workspaceId: $workspaceId) {
              histories(docId: $docId, options: $options) {
                sessionId
                tokens
                action
                createdAt
                messages {
                  role
                  content
                  attachments
                  createdAt
                }
              }
            }
          }
        }
      `,
      variables,
    })
    .expect(200);

  handleGraphQLError(res);

  return res.body.data.currentUser?.copilot?.histories || [];
}

/** Prompt definition referenced by a workflow fixture. */
type Prompt = { name: string; model: string; messages: PromptMessage[] };

/**
 * One workflow fixture: a graph, the prompts it refers to, and parallel
 * arrays of inputs, params and expected results.
 * NOTE(review): how `callCount` is interpreted is decided by the consuming
 * test, which is not part of this file.
 */
type WorkflowTestCase = {
  graph: WorkflowGraph;
  prompts: Prompt[];
  callCount: number[];
  input: string[];
  params: WorkflowParams[];
  result: (string | undefined)[];
};

/** Workflow fixtures covering the chat-text, check-json, check-html and nope nodes. */
export const WorkflowTestCases: WorkflowTestCase[] = [
  {
    prompts: [
      {
        name: 'test1',
        model: 'test',
        messages: [{ role: 'user', content: '{{content}}' }],
      },
    ],
    graph: {
      name: 'test chat text node',
      graph: [
        {
          id: 'start',
          name: 'test chat text node',
          nodeType: WorkflowNodeType.Basic,
          type: NodeExecutorType.ChatText,
          promptName: 'test1',
          edges: [],
        },
      ],
    },
    callCount: [1],
    input: ['test'],
    params: [],
    // Matches the fixed text streamed by MockCopilotTestProvider.generateTextStream.
    result: ['generate text to text stream'],
  },
  {
    prompts: [],
    graph: {
      name: 'test check json node',
      graph: [
        {
          id: 'start',
          name: 'basic node',
          nodeType: WorkflowNodeType.Basic,
          type: NodeExecutorType.CheckJson,
          edges: [],
        },
      ],
    },
    callCount: [1, 1],
    // The second input is deliberately truncated JSON.
    input: ['{"test": "true"}', '{"test": '],
    params: [],
    result: ['true', 'false'],
  },
  {
    prompts: [],
    graph: {
      name: 'test check html node',
      graph: [
        {
          id: 'start',
          name: 'basic node',
          nodeType: WorkflowNodeType.Basic,
          type: NodeExecutorType.CheckHtml,
          edges: [],
        },
      ],
    },
    callCount: [1, 1, 1, 1],
    params: [{}, { strict: 'true' }, {}, {}],
    // NOTE(review): the first three inputs are empty strings here, yet they
    // are expected to produce different results; they look like HTML snippets
    // whose markup was lost — verify against the upstream fixture.
    input: [
      '',
      '',
      '',
      '{"test": "true"}',
    ],
    result: ['true', 'false', 'true', 'false'],
  },
  {
    prompts: [],
    graph: {
      name: 'test nope node',
      graph: [
        {
          id: 'start',
          name: 'nope node',
          nodeType: WorkflowNodeType.Nope,
          edges: [],
        },
      ],
    },
    callCount: [1],
    input: ['test'],
    params: [],
    result: ['test'],
  },
];