/*
 * Mirror of https://github.com/leon-ai/leon.git (synced 2025-01-02 13:43:51 +03:00)
 * File: leon/scripts/train/train.js — 141 lines, 4.4 KiB, JavaScript
 */

import { containerBootstrap } from '@nlpjs/core-loader'
import { Nlp } from '@nlpjs/nlp'
import { LangAll } from '@nlpjs/lang-all'
import dotenv from 'dotenv'
import { LogHelper } from '@/helpers/log-helper'
import { LangHelper } from '@/helpers/lang-helper'
import trainGlobalResolvers from './train-resolvers-model/train-global-resolvers'
import trainSkillsResolvers from './train-resolvers-model/train-skills-resolvers'
import trainGlobalEntities from './train-main-model/train-global-entities'
import trainSkillsActions from './train-main-model/train-skills-actions'
dotenv.config()

/**
 * Bootstrap a dedicated nlp.js container and return its configured `nlp`
 * instance. Each model uses its own container (one `containerBootstrap()`
 * call per model), with NLU logging disabled and a 0.8 threshold.
 *
 * @param {string} modelFileName - Path assigned to `nlp.settings.modelFileName`.
 * @param {boolean} trainByDomain - Value for the NLU manager's `trainByDomain` setting.
 * @returns {Promise<object>} The container's `nlp` instance.
 */
async function createNlp(modelFileName, trainByDomain) {
  const container = await containerBootstrap()
  container.use(Nlp)
  container.use(LangAll)

  const nlp = container.get('nlp')
  const nluManager = container.get('nlu-manager')

  nluManager.settings.log = false
  nluManager.settings.trainByDomain = trainByDomain
  nlp.settings.modelFileName = modelFileName
  nlp.settings.threshold = 0.8

  return nlp
}

/**
 * Training utterance samples script
 *
 * npm run train [en or fr]
 *
 * Builds three nlp.js models (global resolvers, skills resolvers, main) and
 * feeds them the training data of every language returned by
 * `LangHelper.getShortCodes()`. It then calls `train()` on each model in turn.
 * Every model is attempted even if an earlier one fails to train.
 *
 * NOTE(review): no language argument is read here; every short code from
 * `LangHelper.getShortCodes()` is trained.
 *
 * @returns {Promise<void>} Resolves only after all three models trained
 *   successfully. Rejects with the original error if container setup or
 *   training-data loading fails. Rejects with an Error naming the failed
 *   models if any `nlp.train()` call throws.
 */
export default async () => {
  const globalResolversModelFileName =
    'core/data/models/leon-global-resolvers-model.nlp'
  const skillsResolversModelFileName =
    'core/data/models/leon-skills-resolvers-model.nlp'
  const mainModelFileName = 'core/data/models/leon-main-model.nlp'

  let globalResolversNlp
  let skillsResolversNlp
  let mainNlp

  try {
    // Global resolvers are trained as a single model; the other two by domain.
    globalResolversNlp = await createNlp(globalResolversModelFileName, false)
    skillsResolversNlp = await createNlp(skillsResolversModelFileName, true)
    mainNlp = await createNlp(mainModelFileName, true)

    // https://github.com/axa-group/nlp.js/blob/master/examples/17-ner-nlg/index.js
    mainNlp.settings.forceNER = true
    mainNlp.settings.calculateSentiment = true

    for (const lang of LangHelper.getShortCodes()) {
      globalResolversNlp.addLanguage(lang)
      await trainGlobalResolvers(lang, globalResolversNlp)

      skillsResolversNlp.addLanguage(lang)
      await trainSkillsResolvers(lang, skillsResolversNlp)

      mainNlp.addLanguage(lang)
      await trainGlobalEntities(lang, mainNlp)
      await trainSkillsActions(lang, mainNlp)
    }
  } catch (e) {
    // Guard against non-Error throwables so logging itself cannot throw.
    LogHelper.error(e instanceof Error ? e.message : String(e))
    throw e
  }

  const models = [
    {
      label: 'Global resolvers',
      fileName: globalResolversModelFileName,
      nlp: globalResolversNlp
    },
    {
      label: 'Skills resolvers',
      fileName: skillsResolversModelFileName,
      nlp: skillsResolversNlp
    },
    { label: 'Main', fileName: mainModelFileName, nlp: mainNlp }
  ]
  const failedModels = []

  // Train sequentially. A failure is logged and remembered but does not stop
  // the remaining models; the caller is only told about success once all of
  // them have finished.
  for (const { label, fileName, nlp } of models) {
    try {
      await nlp.train()
      LogHelper.success(`${label} NLP model saved in ${fileName}`)
    } catch (e) {
      LogHelper.error(`Failed to save ${label.toLowerCase()} NLP model: ${e}`)
      failedModels.push(label.toLowerCase())
    }
  }

  if (failedModels.length > 0) {
    throw new Error(`Failed to train NLP model(s): ${failedModels.join(', ')}`)
  }
}