Wasm/JS: Pivot translation API JS binding and test page update (#327)

This commit is contained in:
Abhishek Aggarwal 2022-02-02 17:01:23 +01:00 committed by GitHub
parent 19ae519c63
commit d95b014562
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 77 deletions

View File

@ -78,7 +78,8 @@ EMSCRIPTEN_BINDINGS(blocking_service_config) {
EMSCRIPTEN_BINDINGS(blocking_service) {
class_<BlockingService>("BlockingService")
.constructor<BlockingService::Config>()
.function("translate", &BlockingService::translateMultiple);
.function("translate", &BlockingService::translateMultiple)
.function("translateViaPivoting", &BlockingService::pivotMultiple);
register_vector<std::string>("VectorString");
}

View File

@ -1,10 +1,12 @@
// All variables specific to translation service
var translationService = undefined;
// A map of language-pair to TranslationModel object
var languagePairToTranslationModels = new Map();
const BERGAMOT_TRANSLATOR_MODULE = "bergamot-translator-worker.js";
const MODEL_REGISTRY = "modelRegistry.js";
const PIVOT_LANGUAGE = 'en';
const encoder = new TextEncoder(); // string to utf-8 converter
const decoder = new TextDecoder(); // utf-8 to string converter
@ -82,7 +84,7 @@ const constructTranslationService = async () => {
}
}
// Constructs a translation model object for the source and target language pair
// Constructs translation model for the source and target language pair.
const constructTranslationModel = async (from, to) => {
// Delete all previously constructed translation models and clear the map
languagePairToTranslationModels.forEach((value, key) => {
@ -91,32 +93,61 @@ const constructTranslationModel = async (from, to) => {
});
languagePairToTranslationModels.clear();
// If none of the languages is English then construct multiple models with
// English as a pivot language.
if (from !== 'en' && to !== 'en') {
log(`Constructing model '${from}${to}' via pivoting: '${from}en' and 'en${to}'`);
await Promise.all([_constructTranslationModelInvolvingEnglish(from, 'en'),
_constructTranslationModelInvolvingEnglish('en', to)]);
const languagePairs = _getLanguagePairs(from, to);
log(`Constructing translation model(s): ${languagePairs.toString()}`);
if (languagePairs.length == 2) {
// This implies pivoting is required => Construct 2 translation models
await Promise.all([_constructTranslationModelHelper(languagePairs[0]),
_constructTranslationModelHelper(languagePairs[1])]);
}
else {
log(`Constructing model '${from}${to}'`);
await _constructTranslationModelInvolvingEnglish(from, to);
// This implies pivoting is not required => Construct 1 translation model
await _constructTranslationModelHelper(languagePairs[0]);
}
}
// Translates text from source language to target language.
// Translates text from source language to target language (via pivoting if necessary).
const translate = (from, to, input) => {
// If none of the languages is English then perform translation with
// English as a pivot language.
if (from !== 'en' && to !== 'en') {
log(`Translating '${from}${to}' via pivoting: '${from}en' -> 'en${to}'`);
const translatedTextInEnglish = _translateInvolvingEnglish(from, 'en', input);
return _translateInvolvingEnglish('en', to, translatedTextInEnglish);
const languagePairs = _getLanguagePairs(from, to);
log(`Translating for language pair(s): '${languagePairs.toString()}'`);
// Each language pair requires a corresponding loaded translation model. Otherwise, it's an error.
let translationModels = _getLoadedTranslationModels(from, to);
if (translationModels.length != languagePairs.length) {
throw Error(`Insufficient no. of loaded translation models. Required:'${languagePairs.length}' Found:'${translationModels.length}'`);
}
// Prepare the arguments (ResponseOptions and vectorSourceText (vector<string>)) of Translation API and call it.
// Result is a vector<Response> where each of its item corresponds to one item of vectorSourceText in the same order.
const responseOptions = _prepareResponseOptions();
let vectorSourceText = _prepareSourceText(input);
let vectorResponse;
if (translationModels.length == 2) {
// This implies translation should be done via pivoting
vectorResponse = translationService.translateViaPivoting(translationModels[0], translationModels[1], vectorSourceText, responseOptions);
}
else {
log(`Translating '${from}${to}'`);
return _translateInvolvingEnglish(from, to, input);
// This implies translation should be done without pivoting
vectorResponse = translationService.translate(translationModels[0], vectorSourceText, responseOptions);
}
// Parse all relevant information from vectorResponse
const listTranslatedText = _parseTranslatedText(vectorResponse);
const listSourceText = _parseSourceText(vectorResponse);
const listTranslatedTextSentences = _parseTranslatedTextSentences(vectorResponse);
const listSourceTextSentences = _parseSourceTextSentences(vectorResponse);
const listTranslatedTextSentenceQualityScores = _parseTranslatedTextSentenceQualityScores(vectorResponse);
log(`Source text: ${listSourceText}`);
log(`Translated text: ${listTranslatedText}`);
log(`Translated sentences: ${JSON.stringify(listTranslatedTextSentences)}`);
log(`Source sentences: ${JSON.stringify(listSourceTextSentences)}`);
log(`Translated sentence quality scores: ${JSON.stringify(listTranslatedTextSentenceQualityScores)}`);
// Delete prepared SourceText to avoid memory leak
vectorSourceText.delete();
return listTranslatedText;
}
// Downloads file from a url and returns the array buffer
@ -140,38 +171,13 @@ const _prepareAlignedMemoryFromBuffer = async (buffer, alignmentSize) => {
return alignedMemory;
}
const _constructTranslationModelInvolvingEnglish = async (from, to) => {
const languagePair = `${from}${to}`;
const _constructTranslationModelHelper = async (languagePair) => {
/*Set the Model Configuration as YAML formatted string.
For available configuration options, please check: https://marian-nmt.github.io/docs/cmd/marian-decoder/
Vocab files are re-used in both translation directions
const vocabLanguagePair = from === "en" ? `${to}${from}` : languagePair;
const modelConfig = `models:
- /${languagePair}/model.${languagePair}.intgemm.alphas.bin
vocabs:
- /${languagePair}/vocab.${vocabLanguagePair}.spm
- /${languagePair}/vocab.${vocabLanguagePair}.spm
beam-size: 1
normalize: 1.0
word-penalty: 0
max-length-break: 128
mini-batch-words: 1024
workspace: 128
max-length-factor: 2.0
skip-cost: true
cpu-threads: 0
quiet: true
quiet-translation: true
shortlist:
- /${languagePair}/lex.${languagePair}.s2t
- 50
- 50
`;
*/
// TODO: gemm-precision: int8shiftAlphaAll (for the models that support this)
// DONOT CHANGE THE SPACES BETWEEN EACH ENTRY OF CONFIG
Vocab files are re-used in both translation directions.
DO NOT CHANGE THE SPACES BETWEEN EACH ENTRY OF CONFIG
*/
const modelConfig = `beam-size: 1
normalize: 1.0
word-penalty: 0
@ -229,38 +235,35 @@ alignment: soft
languagePairToTranslationModels.set(languagePair, translationModel);
}
const _translateInvolvingEnglish = (from, to, input) => {
const languagePair = `${from}${to}`;
if (!languagePairToTranslationModels.has(languagePair)) {
throw Error(`Please load translation model '${languagePair}' before translating`);
// Tells whether translating between lang1 and lang2 needs an intermediate
// pivot step. Returns false as soon as either language equals PIVOT_LANGUAGE;
// otherwise falls through to return true.
const _isPivotingRequired = (lang1, lang2) => {
if ((lang1 === PIVOT_LANGUAGE) || (lang2 === PIVOT_LANGUAGE)) {
return false;
}
// NOTE(review): `languagePair` is neither a parameter nor a local of this
// function, and `translationModel` is assigned without a declaration. This
// statement looks like a leftover from the removed _translateInvolvingEnglish
// body; confirm and drop it.
translationModel = languagePairToTranslationModels.get(languagePair);
return true;
}
// Prepare the arguments of translate() API i.e. ResponseOptions and vectorSourceText (i.e. a vector<string>)
const responseOptions = _prepareResponseOptions();
let vectorSourceText = _prepareSourceText(input);
// Builds the ordered list of language-pair keys (e.g. "deen") needed to get
// from srcLang to tgtLang.
//  - One entry (`${srcLang}${tgtLang}`) when _isPivotingRequired() says no
//    pivot is needed.
//  - Two entries when pivoting is needed: source->pivot first, pivot->target
//    second.
const _getLanguagePairs = (srcLang, tgtLang) => {
  // Direct translation: a single pair is enough.
  if (!_isPivotingRequired(srcLang, tgtLang)) {
    return [`${srcLang}${tgtLang}`];
  }

  // Keep this order: callers index into the result positionally
  // (first = source->pivot, second = pivot->target).
  const sourceToPivot = `${srcLang}${PIVOT_LANGUAGE}`;
  const pivotToTarget = `${PIVOT_LANGUAGE}${tgtLang}`;
  return [sourceToPivot, pivotToTarget];
};
// Call translate() API; result is vector<Response> where every item of vector<Response> corresponds
// to an item of vectorSourceText in the same order
const vectorResponse = translationService.translate(translationModel, vectorSourceText, responseOptions);
// Parse all relevant information from vectorResponse
const listTranslatedText = _parseTranslatedText(vectorResponse);
const listSourceText = _parseSourceText(vectorResponse);
const listTranslatedTextSentences = _parseTranslatedTextSentences(vectorResponse);
const listSourceTextSentences = _parseSourceTextSentences(vectorResponse);
const listTranslatedTextSentenceQualityScores = _parseTranslatedTextSentenceQualityScores(vectorResponse);
log(`Source text: ${listSourceText}`);
log(`Translated text: ${listTranslatedText}`);
log(`Translated sentences: ${JSON.stringify(listTranslatedTextSentences)}`);
log(`Source sentences: ${JSON.stringify(listSourceTextSentences)}`);
log(`Translated sentence quality scores: ${JSON.stringify(listTranslatedTextSentenceQualityScores)}`);
// Delete prepared SourceText to avoid memory leak
vectorSourceText.delete();
return listTranslatedText;
// Looks up the translation models currently stored in
// languagePairToTranslationModels for every language pair that
// _getLanguagePairs(srcLang, tgtLang) reports.
// Pairs without an entry in the map are skipped, so the returned array can be
// shorter than the list of required pairs; the relative order of the pairs is
// preserved.
const _getLoadedTranslationModels = (srcLang, tgtLang) => {
  const requiredPairs = _getLanguagePairs(srcLang, tgtLang);
  return requiredPairs
    .filter((pairKey) => languagePairToTranslationModels.has(pairKey))
    .map((pairKey) => languagePairToTranslationModels.get(pairKey));
};
const _parseTranslatedText = (vectorResponse) => {