diff --git a/.env.sample b/.env.sample
index a9ed338d..f1dfa190 100644
--- a/.env.sample
+++ b/.env.sample
@@ -14,8 +14,10 @@ LEON_LLM=true
 LEON_LLM_PROVIDER=local
 # LLM provider API key (if not local)
 LEON_LLM_PROVIDER_API_KEY=
-# Enable/disable LLM Natural Language Generation
+# Enable/disable LLM natural language generation
 LEON_LLM_NLG=true
+# Enable/disable LLM action recognition
+LEON_LLM_ACTION_RECOGNITION=true
 # Time zone (current one by default)
 LEON_TIME_ZONE=
 
diff --git a/server/src/core/http-server/api/llm-inference/post.ts b/server/src/core/http-server/api/llm-inference/post.ts
index 3b986027..e6a8572d 100644
--- a/server/src/core/http-server/api/llm-inference/post.ts
+++ b/server/src/core/http-server/api/llm-inference/post.ts
@@ -7,6 +7,7 @@ import { SummarizationLLMDuty } from '@/core/llm-manager/llm-duties/summarizatio
 import { TranslationLLMDuty } from '@/core/llm-manager/llm-duties/translation-llm-duty'
 import { ParaphraseLLMDuty } from '@/core/llm-manager/llm-duties/paraphrase-llm-duty'
 import { ChitChatLLMDuty } from '@/core/llm-manager/llm-duties/chit-chat-llm-duty'
+import { ActionRecognitionLLMDuty } from '@/core/llm-manager/llm-duties/action-recognition-llm-duty'
 import { LLM_MANAGER } from '@/core'
 
 interface PostLLMInferenceSchema {
@@ -19,6 +20,7 @@ interface PostLLMInferenceSchema {
 }
 
 const LLM_DUTIES_MAP = {
+  [LLMDuties.ActionRecognition]: ActionRecognitionLLMDuty,
   [LLMDuties.CustomNER]: CustomNERLLMDuty,
   [LLMDuties.Summarization]: SummarizationLLMDuty,
   [LLMDuties.Translation]: TranslationLLMDuty,
diff --git a/server/src/core/llm-manager/llm-duties/action-recognition-llm-duty.ts b/server/src/core/llm-manager/llm-duties/action-recognition-llm-duty.ts
new file mode 100644
index 00000000..d2060889
--- /dev/null
+++ b/server/src/core/llm-manager/llm-duties/action-recognition-llm-duty.ts
@@ -0,0 +1,81 @@
+import {
+  type LLMDutyParams,
+  type LLMDutyResult,
+  LLMDuty
+} from '@/core/llm-manager/llm-duty'
+import { LogHelper } from '@/helpers/log-helper'
+import { LLM_MANAGER, LLM_PROVIDER } from '@/core'
+import { LLM_THREADS } from '@/core/llm-manager/llm-manager'
+import { LLMProviders } from '@/core/llm-manager/types'
+import { LLM_PROVIDER as LLM_PROVIDER_NAME } from '@/constants'
+
+interface ActionRecognitionLLMDutyParams extends LLMDutyParams {}
+
+const JSON_KEY_RESPONSE = 'action_name'
+
+export class ActionRecognitionLLMDuty extends LLMDuty {
+  protected readonly systemPrompt = `You are an AI expert in intent classification and matching.
+You look up every utterance sample and description. Then you return the most probable intent (action) to be triggered based on a given utterance.
+If the intent is not listed, do not make it up yourself. Instead you must return { "${JSON_KEY_RESPONSE}": "not_found" }.`
+  protected readonly name = 'Action Recognition LLM Duty'
+  protected input: LLMDutyParams['input'] = null
+
+  constructor(params: ActionRecognitionLLMDutyParams) {
+    super()
+
+    LogHelper.title(this.name)
+    LogHelper.success('New instance')
+
+    this.input = params.input
+  }
+
+  public async execute(): Promise<LLMDutyResult | null> {
+    LogHelper.title(this.name)
+    LogHelper.info('Executing...')
+
+    try {
+      const prompt = `Utterance: "${this.input}"`
+      const completionParams = {
+        systemPrompt: this.systemPrompt,
+        data: {
+          [JSON_KEY_RESPONSE]: {
+            type: 'string'
+          }
+        }
+      }
+      let completionResult
+
+      if (LLM_PROVIDER_NAME === LLMProviders.Local) {
+        const { LlamaChatSession } = await Function(
+          'return import("node-llama-cpp")'
+        )()
+
+        const context = await LLM_MANAGER.model.createContext({
+          threads: LLM_THREADS
+        })
+        const session = new LlamaChatSession({
+          contextSequence: context.getSequence(),
+          systemPrompt: completionParams.systemPrompt
+        })
+
+        completionResult = await LLM_PROVIDER.prompt(prompt, {
+          ...completionParams,
+          session,
+          maxTokens: context.contextSize
+        })
+      } else {
+        completionResult = await LLM_PROVIDER.prompt(prompt, completionParams)
+      }
+
+      LogHelper.title(this.name)
+      LogHelper.success(`Duty executed: ${JSON.stringify(completionResult)}`)
+
+      return completionResult as unknown as LLMDutyResult
+    } catch (e) {
+      LogHelper.title(this.name)
+      LogHelper.error(`Failed to execute: ${e}`)
+    }
+
+    return null
+  }
+}
diff --git a/server/src/core/llm-manager/types.ts b/server/src/core/llm-manager/types.ts
index 28a1d8a4..8ebe607c 100644
--- a/server/src/core/llm-manager/types.ts
+++ b/server/src/core/llm-manager/types.ts
@@ -4,6 +4,7 @@ import type { LLMDuty } from '@/core/llm-manager/llm-duty'
 import type { MessageLog } from '@/types'
 
 export enum LLMDuties {
+  ActionRecognition = 'action-recognition',
   CustomNER = 'customer-ner',
   Translation = 'translation',
   Summarization = 'summarization',
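
A minimal usage sketch, assuming the duty is driven directly rather than through the llm-inference route: the constructor and execute() calls mirror the new file's public surface, while the recognizeAction wrapper and the sample utterance are hypothetical.

import { ActionRecognitionLLMDuty } from '@/core/llm-manager/llm-duties/action-recognition-llm-duty'

// Hypothetical wrapper: map a raw utterance to a known action name.
async function recognizeAction(utterance: string): Promise<void> {
  const duty = new ActionRecognitionLLMDuty({ input: utterance })

  // execute() resolves with the provider's completion result, or null if the
  // duty failed. Per the system prompt, the completion is expected to carry
  // { "action_name": "<matched action>" } or { "action_name": "not_found" }.
  const result = await duty.execute()

  console.log(result)
}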