mirror of
https://github.com/leon-ai/leon.git
synced 2025-01-03 06:06:06 +03:00
feat: create new NLP skills resolvers model + NLP global resolvers model
This commit is contained in:
parent
29c5348b5a
commit
602604e437
@ -32,7 +32,7 @@ export default (lang, nlp) => new Promise(async (resolve) => {
|
||||
const resolver = resolvers[resolverName]
|
||||
const intentKeys = Object.keys(resolver.intents)
|
||||
|
||||
log.info(`[${lang}] Training "${resolverName}" resolver...`)
|
||||
log.info(`[${lang}] Training ${skillName} "${resolverName}" resolver...`)
|
||||
|
||||
intentKeys.forEach((intentName) => {
|
||||
const intent = `resolver.${currentSkill.name}.${resolverName}.${intentName}`
|
||||
@ -50,7 +50,7 @@ export default (lang, nlp) => new Promise(async (resolve) => {
|
||||
})
|
||||
})
|
||||
|
||||
log.success(`[${lang}] "${resolverName}" resolver trained`)
|
||||
log.success(`[${lang}] ${skillName} "${resolverName}" resolver trained`)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -18,25 +18,42 @@ dotenv.config()
|
||||
* npm run train [en or fr]
|
||||
*/
|
||||
export default () => new Promise(async (resolve, reject) => {
|
||||
const resolversModelFileName = 'core/data/models/leon-resolvers-model.nlp'
|
||||
const globalResolversModelFileName = 'core/data/models/leon-global-resolvers-model.nlp'
|
||||
const skillsResolversModelFileName = 'core/data/models/leon-skills-resolvers-model.nlp'
|
||||
const mainModelFileName = 'core/data/models/leon-main-model.nlp'
|
||||
|
||||
try {
|
||||
/**
|
||||
* Resolvers NLP model configuration
|
||||
* Global resolvers NLP model configuration
|
||||
*/
|
||||
const resolversContainer = await containerBootstrap()
|
||||
const globalResolversContainer = await containerBootstrap()
|
||||
|
||||
resolversContainer.use(Nlp)
|
||||
resolversContainer.use(LangAll)
|
||||
globalResolversContainer.use(Nlp)
|
||||
globalResolversContainer.use(LangAll)
|
||||
|
||||
const resolversNlp = resolversContainer.get('nlp')
|
||||
const resolversNluManager = resolversContainer.get('nlu-manager')
|
||||
const globalResolversNlp = globalResolversContainer.get('nlp')
|
||||
const globalResolversNluManager = globalResolversContainer.get('nlu-manager')
|
||||
|
||||
resolversNluManager.settings.log = false
|
||||
resolversNluManager.settings.trainByDomain = true
|
||||
resolversNlp.settings.modelFileName = resolversModelFileName
|
||||
resolversNlp.settings.threshold = 0.8
|
||||
globalResolversNluManager.settings.log = false
|
||||
globalResolversNluManager.settings.trainByDomain = false
|
||||
globalResolversNlp.settings.modelFileName = globalResolversModelFileName
|
||||
globalResolversNlp.settings.threshold = 0.8
|
||||
|
||||
/**
|
||||
* Skills resolvers NLP model configuration
|
||||
*/
|
||||
const skillsResolversContainer = await containerBootstrap()
|
||||
|
||||
skillsResolversContainer.use(Nlp)
|
||||
skillsResolversContainer.use(LangAll)
|
||||
|
||||
const skillsResolversNlp = skillsResolversContainer.get('nlp')
|
||||
const skillsResolversNluManager = skillsResolversContainer.get('nlu-manager')
|
||||
|
||||
skillsResolversNluManager.settings.log = false
|
||||
skillsResolversNluManager.settings.trainByDomain = true
|
||||
skillsResolversNlp.settings.modelFileName = skillsResolversModelFileName
|
||||
skillsResolversNlp.settings.threshold = 0.8
|
||||
|
||||
/**
|
||||
* Main NLP model configuration
|
||||
@ -66,11 +83,13 @@ export default () => new Promise(async (resolve, reject) => {
|
||||
for (let h = 0; h < shortLangs.length; h += 1) {
|
||||
const lang = shortLangs[h]
|
||||
|
||||
resolversNlp.addLanguage(lang)
|
||||
globalResolversNlp.addLanguage(lang)
|
||||
// eslint-disable-next-line no-await-in-loop
|
||||
await trainGlobalResolvers(lang, resolversNlp)
|
||||
await trainGlobalResolvers(lang, globalResolversNlp)
|
||||
|
||||
skillsResolversNlp.addLanguage(lang)
|
||||
// eslint-disable-next-line no-await-in-loop
|
||||
await trainSkillsResolvers(lang, resolversNlp)
|
||||
await trainSkillsResolvers(lang, skillsResolversNlp)
|
||||
|
||||
mainNlp.addLanguage(lang)
|
||||
// eslint-disable-next-line no-await-in-loop
|
||||
@ -80,12 +99,22 @@ export default () => new Promise(async (resolve, reject) => {
|
||||
}
|
||||
|
||||
try {
|
||||
await resolversNlp.train()
|
||||
await globalResolversNlp.train()
|
||||
|
||||
log.success(`Resolvers NLP model saved in ${resolversModelFileName}`)
|
||||
log.success(`Global resolvers NLP model saved in ${globalResolversModelFileName}`)
|
||||
resolve()
|
||||
} catch (e) {
|
||||
log.error(`Failed to save resolvers NLP model: ${e}`)
|
||||
log.error(`Failed to save global resolvers NLP model: ${e}`)
|
||||
reject()
|
||||
}
|
||||
|
||||
try {
|
||||
await skillsResolversNlp.train()
|
||||
|
||||
log.success(`Skills resolvers NLP model saved in ${skillsResolversModelFileName}`)
|
||||
resolve()
|
||||
} catch (e) {
|
||||
log.error(`Failed to save skills resolvers NLP model: ${e}`)
|
||||
reject()
|
||||
}
|
||||
|
||||
|
@ -297,7 +297,8 @@ server.init = async () => {
|
||||
// Load NLP models
|
||||
try {
|
||||
await Promise.all([
|
||||
nlu.loadResolversModel(join(process.cwd(), 'core/data/models/leon-resolvers-model.nlp')),
|
||||
nlu.loadGlobalResolversModel(join(process.cwd(), 'core/data/models/leon-global-resolvers-model.nlp')),
|
||||
nlu.loadSkillsResolversModel(join(process.cwd(), 'core/data/models/leon-skills-resolvers-model.nlp')),
|
||||
nlu.loadMainModel(join(process.cwd(), 'core/data/models/leon-main-model.nlp'))
|
||||
])
|
||||
} catch (e) {
|
||||
|
@ -38,7 +38,8 @@ class Nlu {
|
||||
constructor (brain) {
|
||||
this.brain = brain
|
||||
this.request = request
|
||||
this.resolversNlp = { }
|
||||
this.globalResolversNlp = { }
|
||||
this.skillsResolversNlp = { }
|
||||
this.mainNlp = { }
|
||||
this.ner = { }
|
||||
this.conv = new Conversation('conv0')
|
||||
@ -49,13 +50,13 @@ class Nlu {
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the resolvers NLP model from the latest training
|
||||
* Load the global resolvers NLP model from the latest training
|
||||
*/
|
||||
loadResolversModel (nlpModel) {
|
||||
loadGlobalResolversModel (nlpModel) {
|
||||
return new Promise(async (resolve, reject) => {
|
||||
if (!fs.existsSync(nlpModel)) {
|
||||
log.title('NLU')
|
||||
reject({ type: 'warning', obj: new Error('The resolvers NLP model does not exist, please run: npm run train') })
|
||||
reject({ type: 'warning', obj: new Error('The global resolvers NLP model does not exist, please run: npm run train') })
|
||||
} else {
|
||||
log.title('NLU')
|
||||
|
||||
@ -65,12 +66,47 @@ class Nlu {
|
||||
container.use(Nlp)
|
||||
container.use(LangAll)
|
||||
|
||||
this.resolversNlp = container.get('nlp')
|
||||
this.globalResolversNlp = container.get('nlp')
|
||||
const nluManager = container.get('nlu-manager')
|
||||
nluManager.settings.spellCheck = true
|
||||
|
||||
await this.resolversNlp.load(nlpModel)
|
||||
log.success('Resolvers NLP model loaded')
|
||||
await this.globalResolversNlp.load(nlpModel)
|
||||
log.success('Global resolvers NLP model loaded')
|
||||
|
||||
resolve()
|
||||
} catch (err) {
|
||||
this.brain.talk(`${this.brain.wernicke('random_errors')}! ${this.brain.wernicke('errors', 'nlu', { '%error%': err.message })}.`)
|
||||
this.brain.socket.emit('is-typing', false)
|
||||
|
||||
reject({ type: 'error', obj: err })
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the skills resolvers NLP model from the latest training
|
||||
*/
|
||||
loadSkillsResolversModel (nlpModel) {
|
||||
return new Promise(async (resolve, reject) => {
|
||||
if (!fs.existsSync(nlpModel)) {
|
||||
log.title('NLU')
|
||||
reject({ type: 'warning', obj: new Error('The skills resolvers NLP model does not exist, please run: npm run train') })
|
||||
} else {
|
||||
log.title('NLU')
|
||||
|
||||
try {
|
||||
const container = await containerBootstrap()
|
||||
|
||||
container.use(Nlp)
|
||||
container.use(LangAll)
|
||||
|
||||
this.skillsResolversNlp = container.get('nlp')
|
||||
const nluManager = container.get('nlu-manager')
|
||||
nluManager.settings.spellCheck = true
|
||||
|
||||
await this.skillsResolversNlp.load(nlpModel)
|
||||
log.success('Skills resolvers NLP model loaded')
|
||||
|
||||
resolve()
|
||||
} catch (err) {
|
||||
@ -127,7 +163,8 @@ class Nlu {
|
||||
* Check if NLP models exists
|
||||
*/
|
||||
hasNlpModels () {
|
||||
return Object.keys(this.resolversNlp).length > 0
|
||||
return Object.keys(this.globalResolversNlp).length > 0
|
||||
&& Object.keys(this.skillsResolversNlp).length > 0
|
||||
&& Object.keys(this.mainNlp).length > 0
|
||||
}
|
||||
|
||||
@ -235,35 +272,12 @@ class Nlu {
|
||||
hasMatchingEntity = this.nluResultObj
|
||||
.entities.filter(({ entity }) => expectedItemName === entity).length > 0
|
||||
} else if (expectedItemType.indexOf('resolver') !== -1) {
|
||||
const result = await this.resolversNlp.process(utterance)
|
||||
const { classifications } = result
|
||||
let { intent } = result
|
||||
|
||||
/**
|
||||
* Prioritize skill resolvers in the classification
|
||||
* to not overlap with resolvers from other skills
|
||||
*/
|
||||
if (this.conv.hasActiveContext()) {
|
||||
const classification = classifications.find(({ intent: newIntent, score: newScore }) => {
|
||||
const [, skillName] = newIntent.split('.')
|
||||
|
||||
if (expectedItemType === 'skill_resolver') {
|
||||
// Prioritize skill resolver intent
|
||||
if (newScore > 0.6) {
|
||||
return this.nluResultObj.classification.skill === skillName
|
||||
}
|
||||
}
|
||||
|
||||
if (expectedItemType === 'global_resolver') {
|
||||
// Use a global resolver if any
|
||||
return skillName === 'global'
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
// eslint-disable-next-line prefer-destructuring
|
||||
intent = classification?.intent
|
||||
const nlpObjs = {
|
||||
global_resolver: this.globalResolversNlp,
|
||||
skill_resolver: this.skillsResolversNlp
|
||||
}
|
||||
const result = await nlpObjs[expectedItemType].process(utterance)
|
||||
const { intent } = result
|
||||
|
||||
const resolveResolvers = (resolver, intent) => {
|
||||
const resolversPath = join(process.cwd(), 'core/data', this.brain.lang, 'global-resolvers')
|
||||
@ -286,7 +300,7 @@ class Nlu {
|
||||
log.title('NLU')
|
||||
log.success('Resolvers resolved:')
|
||||
this.nluResultObj.resolvers = resolveResolvers(expectedItemName, intent)
|
||||
this.nluResultObj.resolvers.forEach((resolver) => log.success(JSON.stringify(resolver)))
|
||||
this.nluResultObj.resolvers.forEach((resolver) => log.success(`${intent}: ${JSON.stringify(resolver)}`))
|
||||
hasMatchingResolver = this.nluResultObj.resolvers.length > 0
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user