1
1
mirror of https://github.com/leon-ai/leon.git synced 2024-11-30 19:07:39 +03:00

refactor(server): NER

This commit is contained in:
louistiti 2023-03-20 22:04:04 +08:00
parent 6820501da6
commit a8e3fb8b87
6 changed files with 174 additions and 154 deletions

View File

@ -4,7 +4,7 @@
# Language currently used
LEON_LANG=en-US
# HttpServer
# Server
LEON_HOST=http://localhost
LEON_PORT=1337

View File

@ -26,6 +26,7 @@ import type {
NLPDomain,
NLPSkill,
NLPUtterance,
NLUResolver,
NLUResult,
NLUSlot,
NLUSlots
@ -93,8 +94,8 @@ interface IntentObject {
utterance: NLPUtterance
current_entities: NEREntity[]
entities: NEREntity[]
current_resolvers: NERCustomEntity[]
resolvers: NERCustomEntity[]
current_resolvers: NLUResolver[]
resolvers: NLUResolver[]
slots: { [key: string]: NLUSlot['value'] | undefined }
}

View File

@ -1,6 +1,13 @@
import fs from 'node:fs'
import type { NEREntity } from '@/core/nlp/types'
import type { ShortLanguageCode } from '@/types'
import type { NEREntity, NERSpacyEntity, NLPUtterance, NLUResult } from '@/core/nlp/types'
import type {
SkillConfigSchema,
SkillCustomEnumEntityTypeSchema,
SkillCustomRegexEntityTypeSchema,
SkillCustomTrimEntityTypeSchema
} from '@/schemas/skill-schemas'
import { TCP_CLIENT } from '@/core'
import { LogHelper } from '@/helpers/log-helper'
import { StringHelper } from '@/helpers/string-helper'
@ -52,20 +59,24 @@ export default class NER {
/**
* Grab entities and match them with the utterance
*/
public extractEntities(lang, utteranceSamplesFilePath, obj) {
public extractEntities(
lang: ShortLanguageCode,
skillConfigPath: string,
nluResult: NLUResult
): Promise<NEREntity[]> {
return new Promise(async (resolve) => {
LogHelper.title('NER')
LogHelper.info('Looking for entities...')
const { classification } = obj
const { classification } = nluResult
// Remove end-punctuation and add an end-whitespace
const utterance = `${StringHelper.removeEndPunctuation(obj.utterance)} `
const { actions } = JSON.parse(
fs.readFileSync(utteranceSamplesFilePath, 'utf8')
const utterance = `${StringHelper.removeEndPunctuation(nluResult.utterance)} `
const { actions }: { actions: SkillConfigSchema['actions'] } = JSON.parse(
fs.readFileSync(skillConfigPath, 'utf8')
)
const { action } = classification
const promises = []
const actionEntities = actions[action].entities || []
const actionEntities = actions[action]?.entities || []
/**
* Browse action entities
@ -74,18 +85,18 @@ export default class NER {
for (let i = 0; i < actionEntities.length; i += 1) {
const entity = actionEntities[i]
if (entity.type === 'regex') {
if (entity?.type === 'regex') {
promises.push(this.injectRegexEntity(lang, entity))
} else if (entity.type === 'trim') {
} else if (entity?.type === 'trim') {
promises.push(this.injectTrimEntity(lang, entity))
} else if (entity.type === 'enum') {
} else if (entity?.type === 'enum') {
promises.push(this.injectEnumEntity(lang, entity))
}
}
await Promise.all(promises)
const { entities } = await this.manager.process({
const { entities }: { entities: NEREntity[] } = await this.manager.process({
locale: lang,
text: utterance
})
@ -105,7 +116,6 @@ export default class NER {
})
if (entities.length > 0) {
console.log('entities', entities)
NER.logExtraction(entities)
return resolve(entities)
}
@ -119,11 +129,12 @@ export default class NER {
/**
* Get spaCy entities from the TCP server
*/
public getSpacyEntities(utterance) {
public getSpacyEntities(utterance: NLPUtterance): Promise<NERSpacyEntity[]> {
return new Promise((resolve) => {
const spacyEntitiesReceivedHandler = async ({ spacyEntities }) => {
resolve(spacyEntities)
}
const spacyEntitiesReceivedHandler =
async ({ spacyEntities }: { spacyEntities: NERSpacyEntity[] }): Promise<void> => {
resolve(spacyEntities)
}
TCP_CLIENT.ee.removeAllListeners()
TCP_CLIENT.ee.on('spacy-entities-received', spacyEntitiesReceivedHandler)
@ -135,35 +146,38 @@ export default class NER {
/**
* Inject trim type entities
*/
injectTrimEntity(lang, entity) {
private injectTrimEntity(
lang: ShortLanguageCode,
entityConfig: SkillCustomTrimEntityTypeSchema
): Promise<void> {
return new Promise((resolve) => {
for (let j = 0; j < entity.conditions.length; j += 1) {
const condition = entity.conditions[j]
for (let j = 0; j < entityConfig.conditions.length; j += 1) {
const condition = entityConfig.conditions[j]
const conditionMethod = `add${StringHelper.snakeToPascalCase(
condition.type
condition?.type || ''
)}Condition`
if (condition.type === 'between') {
if (condition?.type === 'between') {
/**
* Conditions: https://github.com/axa-group/nlp.js/blob/master/docs/v3/ner-manager.md#trim-named-entities
* e.g. list.addBetweenCondition('en', 'list', 'create a', 'list')
*/
this.manager[conditionMethod](
lang,
entity.name,
condition.from,
condition.to
entityConfig.name,
condition?.from,
condition?.to
)
} else if (condition.type.indexOf('after') !== -1) {
} else if (condition?.type.indexOf('after') !== -1) {
const rule = {
type: 'afterLast',
words: condition.from,
words: condition?.from,
options: {}
}
this.manager.addRule(lang, entity.name, 'trim', rule)
this.manager[conditionMethod](lang, entity.name, condition.from)
this.manager.addRule(lang, entityConfig.name, 'trim', rule)
this.manager[conditionMethod](lang, entityConfig.name, condition?.from)
} else if (condition.type.indexOf('before') !== -1) {
this.manager[conditionMethod](lang, entity.name, condition.to)
this.manager[conditionMethod](lang, entityConfig.name, condition.to)
}
}
@ -174,9 +188,12 @@ export default class NER {
/**
* Inject regex type entities
*/
injectRegexEntity(lang, entity) {
private injectRegexEntity(
lang: ShortLanguageCode,
entityConfig: SkillCustomRegexEntityTypeSchema
): Promise<void> {
return new Promise((resolve) => {
this.manager.addRegexRule(lang, entity.name, new RegExp(entity.regex, 'g'))
this.manager.addRegexRule(lang, entityConfig.name, new RegExp(entityConfig.regex, 'g'))
resolve()
})
@ -185,13 +202,16 @@ export default class NER {
/**
* Inject enum type entities
*/
injectEnumEntity(lang, entity) {
private injectEnumEntity(
lang: ShortLanguageCode,
entityConfig: SkillCustomEnumEntityTypeSchema
): Promise<void> {
return new Promise((resolve) => {
const { name: entityName, options } = entity
const { name: entityName, options } = entityConfig
const optionKeys = Object.keys(options)
optionKeys.forEach((optionName) => {
const { synonyms } = options[optionName]
const { synonyms } = options[optionName] as { synonyms: string[] }
this.manager.addRuleOptionTexts(lang, entityName, optionName, synonyms)
})

View File

@ -1,9 +1,3 @@
/**
* TODO:
* create a "model-loader" class
*
*/
import fs from 'node:fs'
import { join } from 'node:path'
import { spawn } from 'node:child_process'
@ -47,7 +41,7 @@ export default class NLU {
this.mainNlp = {}
this.ner = {}
this.conv = new Conversation('conv0')
this.nluResultObj = defaultNluResultObj // TODO
this.nluResult = defaultNluResultObj // TODO
LogHelper.title('NLU')
LogHelper.success('New instance')
@ -94,7 +88,7 @@ export default class NLU {
version,
utterance,
lang: BRAIN.lang,
classification: this.nluResultObj.classification
classification: this.nluResult.classification
}
})
}
@ -133,7 +127,7 @@ export default class NLU {
skillName,
`config/${BRAIN.lang}.json`
)
this.nluResultObj = {
this.nluResult = {
...defaultNluResultObj, // Reset entities, slots, etc.
slots: this.conv.activeContext.slots,
utterance,
@ -145,16 +139,16 @@ export default class NLU {
confidence: 1
}
}
this.nluResultObj.entities = await NER.extractEntities(
this.nluResult.entities = await NER.extractEntities(
BRAIN.lang,
configDataFilePath,
this.nluResultObj
this.nluResult
)
const { actions, resolvers } = JSON.parse(
fs.readFileSync(configDataFilePath, 'utf8')
)
const action = actions[this.nluResultObj.classification.action]
const action = actions[this.nluResult.classification.action]
const { name: expectedItemName, type: expectedItemType } =
action.loop.expected_item
let hasMatchingEntity = false
@ -162,7 +156,7 @@ export default class NLU {
if (expectedItemType === 'entity') {
hasMatchingEntity =
this.nluResultObj.entities.filter(
this.nluResult.entities.filter(
({ entity }) => expectedItemName === entity
).length > 0
} else if (expectedItemType.indexOf('resolver') !== -1) {
@ -204,11 +198,11 @@ export default class NLU {
) {
LogHelper.title('NLU')
LogHelper.success('Resolvers resolved:')
this.nluResultObj.resolvers = resolveResolvers(expectedItemName, intent)
this.nluResultObj.resolvers.forEach((resolver) =>
this.nluResult.resolvers = resolveResolvers(expectedItemName, intent)
this.nluResult.resolvers.forEach((resolver) =>
LogHelper.success(`${intent}: ${JSON.stringify(resolver)}`)
)
hasMatchingResolver = this.nluResultObj.resolvers.length > 0
hasMatchingResolver = this.nluResult.resolvers.length > 0
}
}
@ -221,7 +215,7 @@ export default class NLU {
}
try {
const processedData = await BRAIN.execute(this.nluResultObj)
const processedData = await BRAIN.execute(this.nluResult)
// Reprocess with the original utterance that triggered the context at first
if (processedData.core?.restart === true) {
const { originalUtterance } = this.conv.activeContext
@ -361,7 +355,7 @@ export default class NLU {
}
const [skillName, actionName] = intent.split('.')
this.nluResultObj = {
this.nluResult = {
...defaultNluResultObj, // Reset entities, slots, etc.
utterance,
answers, // For dialog action type
@ -417,28 +411,28 @@ export default class NLU {
})
}
this.nluResultObj = fallback
this.nluResult = fallback
}
LogHelper.title('NLU')
LogHelper.success(
`Intent found: ${this.nluResultObj.classification.skill}.${this.nluResultObj.classification.action} (domain: ${this.nluResultObj.classification.domain})`
`Intent found: ${this.nluResult.classification.skill}.${this.nluResult.classification.action} (domain: ${this.nluResult.classification.domain})`
)
const configDataFilePath = join(
process.cwd(),
'skills',
this.nluResultObj.classification.domain,
this.nluResultObj.classification.skill,
this.nluResult.classification.domain,
this.nluResult.classification.skill,
`config/${BRAIN.lang}.json`
)
this.nluResultObj.configDataFilePath = configDataFilePath
this.nluResult.configDataFilePath = configDataFilePath
try {
this.nluResultObj.entities = await NER.extractEntities(
this.nluResult.entities = await NER.extractEntities(
BRAIN.lang,
configDataFilePath,
this.nluResultObj
this.nluResult
)
} catch (e) {
// TODO: "!" message, just do simple generic error handler
@ -468,7 +462,7 @@ export default class NLU {
}
}
const newContextName = `${this.nluResultObj.classification.domain}.${skillName}`
const newContextName = `${this.nluResult.classification.domain}.${skillName}`
if (this.conv.activeContext.name !== newContextName) {
this.conv.cleanActiveContext()
}
@ -476,21 +470,21 @@ export default class NLU {
lang: BRAIN.lang,
slots: {},
isInActionLoop: false,
originalUtterance: this.nluResultObj.utterance,
configDataFilePath: this.nluResultObj.configDataFilePath,
actionName: this.nluResultObj.classification.action,
domain: this.nluResultObj.classification.domain,
originalUtterance: this.nluResult.utterance,
configDataFilePath: this.nluResult.configDataFilePath,
actionName: this.nluResult.classification.action,
domain: this.nluResult.classification.domain,
intent,
entities: this.nluResultObj.entities
entities: this.nluResult.entities
}
// Pass current utterance entities to the NLU result object
this.nluResultObj.currentEntities =
this.nluResult.currentEntities =
this.conv.activeContext.currentEntities
// Pass context entities to the NLU result object
this.nluResultObj.entities = this.conv.activeContext.entities
this.nluResult.entities = this.conv.activeContext.entities
try {
const processedData = await BRAIN.execute(this.nluResultObj)
const processedData = await BRAIN.execute(this.nluResult)
// Prepare next action if there is one queuing
if (processedData.nextAction) {
@ -548,7 +542,7 @@ export default class NLU {
`config/${BRAIN.lang}.json`
)
this.nluResultObj = {
this.nluResult = {
...defaultNluResultObj, // Reset entities, slots, etc.
utterance,
classification: {
@ -557,10 +551,11 @@ export default class NLU {
action: actionName
}
}
const entities = await NER.extractEntities(
BRAIN.lang,
configDataFilePath,
this.nluResultObj
this.nluResult
)
// Continue to loop for questions if a slot has been filled correctly
@ -586,7 +581,7 @@ export default class NLU {
if (!this.conv.areSlotsAllFilled()) {
BRAIN.talk(`${BRAIN.wernicke('random_context_out_of_topic')}.`)
} else {
this.nluResultObj = {
this.nluResult = {
...defaultNluResultObj, // Reset entities, slots, etc.
// Assign slots only if there is a next action
slots: this.conv.activeContext.nextAction
@ -604,7 +599,7 @@ export default class NLU {
this.conv.cleanActiveContext()
return BRAIN.execute(this.nluResultObj)
return BRAIN.execute(this.nluResult)
}
this.conv.cleanActiveContext()
@ -626,22 +621,22 @@ export default class NLU {
lang: BRAIN.lang,
slots,
isInActionLoop: false,
originalUtterance: this.nluResultObj.utterance,
configDataFilePath: this.nluResultObj.configDataFilePath,
actionName: this.nluResultObj.classification.action,
domain: this.nluResultObj.classification.domain,
originalUtterance: this.nluResult.utterance,
configDataFilePath: this.nluResult.configDataFilePath,
actionName: this.nluResult.classification.action,
domain: this.nluResult.classification.domain,
intent,
entities: this.nluResultObj.entities
entities: this.nluResult.entities
}
const notFilledSlot = this.conv.getNotFilledSlot()
// Loop for questions if a slot hasn't been filled
if (notFilledSlot) {
const { actions } = JSON.parse(
fs.readFileSync(this.nluResultObj.configDataFilePath, 'utf8')
fs.readFileSync(this.nluResult.configDataFilePath, 'utf8')
)
const [currentSlot] = actions[
this.nluResultObj.classification.action
this.nluResult.classification.action
].slots.filter(({ name }) => name === notFilledSlot.name)
SOCKET_SERVER.socket.emit('suggest', currentSlot.suggestions)
@ -660,7 +655,7 @@ export default class NLU {
* according to the wished skill action
*/
fallback(fallbacks) {
const words = this.nluResultObj.utterance.toLowerCase().split(' ')
const words = this.nluResult.utterance.toLowerCase().split(' ')
if (fallbacks.length > 0) {
LogHelper.info('Looking for fallbacks...')
@ -674,14 +669,14 @@ export default class NLU {
}
if (JSON.stringify(tmpWords) === JSON.stringify(fallbacks[i].words)) {
this.nluResultObj.entities = []
this.nluResultObj.classification.domain = fallbacks[i].domain
this.nluResultObj.classification.skill = fallbacks[i].skill
this.nluResultObj.classification.action = fallbacks[i].action
this.nluResultObj.classification.confidence = 1
this.nluResult.entities = []
this.nluResult.classification.domain = fallbacks[i].domain
this.nluResult.classification.skill = fallbacks[i].skill
this.nluResult.classification.action = fallbacks[i].action
this.nluResult.classification.confidence = 1
LogHelper.success('Fallback found')
return this.nluResultObj
return this.nluResult
}
}
}

View File

@ -27,12 +27,16 @@ export interface NLUClassification {
confidence: number
}
// TODO
/**
 * A resolver entry made of a `name` and a string `value`.
 * Used as the element type of the `currentResolvers` and `resolvers`
 * fields of `NLUResult`.
 * NOTE(review): the shape is still flagged as TODO by the surrounding code — confirm before relying on it.
 */
export interface NLUResolver {
name: string
value: string
}
export interface NLUResult {
currentEntities: NEREntity[]
entities: NEREntity[]
currentResolvers: [] // TODO
resolvers: [] // TODO
currentResolvers: NLUResolver[]
resolvers: NLUResolver[]
slots: NLUSlots
utterance: NLPUtterance
configDataFilePath: string
@ -228,7 +232,7 @@ export interface CustomEnumEntity extends CustomEntity<'enum'> {
}
}
type GlobalEntity = CustomEnumEntity
interface CustomRegexEntity extends CustomEntity<'regex'> {
export interface CustomRegexEntity extends CustomEntity<'regex'> {
resolution: {
value: string
}

View File

@ -10,69 +10,66 @@ const skillDataTypes = [
Type.Literal('global_resolver'),
Type.Literal('entity')
]
const skillCustomEntityTypes = [
Type.Array(
Type.Object(
{
type: Type.Literal('trim'),
name: Type.String({ minLength: 1 }),
conditions: Type.Array(
Type.Object(
{
type: Type.Union([
Type.Literal('between'),
Type.Literal('after'),
Type.Literal('after_first'),
Type.Literal('after_last'),
Type.Literal('before'),
Type.Literal('before_first'),
Type.Literal('before_last')
]),
from: Type.Optional(
Type.Union([
Type.Array(Type.String({ minLength: 1 })),
Type.String({ minLength: 1 })
])
),
to: Type.Optional(
Type.Union([
Type.Array(Type.String({ minLength: 1 })),
Type.String({ minLength: 1 })
])
)
},
{ additionalProperties: false }
/**
 * Schema for an `enum` custom entity: a `name` plus an `options` record that
 * maps each non-empty option key to an object holding its `synonyms`
 * (an array of non-empty strings).
 * The outer object is declared with `additionalProperties: false`.
 */
const skillCustomEnumEntityType = Type.Object(
{
type: Type.Literal('enum'),
// NOTE(review): no `minLength` here, unlike the `name` of the regex and trim schemas — confirm this is intended
name: Type.String(),
options: Type.Record(
Type.String({ minLength: 1 }),
Type.Object({
synonyms: Type.Array(Type.String({ minLength: 1 }))
})
)
},
{ additionalProperties: false }
)
/**
 * Schema for a `regex` custom entity: a non-empty `name` and a non-empty
 * `regex` string holding the pattern source.
 * The object is declared with `additionalProperties: false`.
 */
const skillCustomRegexEntityType = Type.Object(
{
type: Type.Literal('regex'),
name: Type.String({ minLength: 1 }),
regex: Type.String({ minLength: 1 })
},
{ additionalProperties: false }
)
const skillCustomTrimEntityType = Type.Object(
{
type: Type.Literal('trim'),
name: Type.String({ minLength: 1 }),
conditions: Type.Array(
Type.Object(
{
type: Type.Union([
Type.Literal('between'),
Type.Literal('after'),
Type.Literal('after_first'),
Type.Literal('after_last'),
Type.Literal('before'),
Type.Literal('before_first'),
Type.Literal('before_last')
]),
from: Type.Optional(
Type.Union([
Type.Array(Type.String({ minLength: 1 })),
Type.String({ minLength: 1 })
])
),
to: Type.Optional(
Type.Union([
Type.Array(Type.String({ minLength: 1 })),
Type.String({ minLength: 1 })
])
)
)
},
{ additionalProperties: false }
},
{ additionalProperties: false }
)
)
),
Type.Array(
Type.Object(
{
type: Type.Literal('regex'),
name: Type.String({ minLength: 1 }),
regex: Type.String({ minLength: 1 })
},
{ additionalProperties: false }
)
),
Type.Array(
Type.Object(
{
type: Type.Literal('enum'),
name: Type.String(),
options: Type.Record(
Type.String({ minLength: 1 }),
Type.Object({
synonyms: Type.Array(Type.String({ minLength: 1 }))
})
)
},
{ additionalProperties: false }
)
)
},
{ additionalProperties: false }
)
// One array schema per custom entity kind, in the order: trim, regex, enum
const skillCustomEntityTypes = [
Type.Array(skillCustomTrimEntityType),
Type.Array(skillCustomRegexEntityType),
Type.Array(skillCustomEnumEntityType)
]
export const domainSchemaObject = Type.Strict(
@ -195,3 +192,6 @@ export type DomainSchema = Static<typeof domainSchemaObject>
export type SkillSchema = Static<typeof skillSchemaObject>
export type SkillConfigSchema = Static<typeof skillConfigSchemaObject>
export type SkillBridgeSchema = Static<typeof skillSchemaObject.bridge>
// Static types derived from the individual custom entity schemas (trim, regex, enum)
export type SkillCustomTrimEntityTypeSchema = Static<typeof skillCustomTrimEntityType>
export type SkillCustomRegexEntityTypeSchema = Static<typeof skillCustomRegexEntityType>
export type SkillCustomEnumEntityTypeSchema = Static<typeof skillCustomEnumEntityType>