JS: Refactoring wasm test page (#354)

* Properly free all objects that were constructed for the translation API
* Refactored pivot detection mechanism
This commit is contained in:
Abhishek Aggarwal 2022-02-17 14:16:26 +01:00 committed by GitHub
parent 9f55fb4756
commit 2844cedb0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -85,7 +85,8 @@ const constructTranslationService = async () => {
}
}
// Constructs translation model for the source and target language pair.
// Constructs translation model(s) for the source and target language pair (using
// pivoting if required).
const constructTranslationModel = async (from, to) => {
// Delete all previously constructed translation models and clear the map
languagePairToTranslationModels.forEach((value, key) => {
@ -94,62 +95,63 @@ const constructTranslationModel = async (from, to) => {
});
languagePairToTranslationModels.clear();
const languagePairs = _getLanguagePairs(from, to);
log(`Constructing translation model(s): ${languagePairs.toString()}`);
if (languagePairs.length == 2) {
// This implies pivoting is required => Construct 2 translation models
await Promise.all([_constructTranslationModelHelper(languagePairs[0]),
_constructTranslationModelHelper(languagePairs[1])]);
if (_isPivotingRequired(from, to)) {
// Pivoting requires 2 translation models
const languagePairSrcToPivot = _getLanguagePair(from, PIVOT_LANGUAGE);
const languagePairPivotToTarget = _getLanguagePair(PIVOT_LANGUAGE, to);
await Promise.all([_constructTranslationModelHelper(languagePairSrcToPivot),
_constructTranslationModelHelper(languagePairPivotToTarget)]);
}
else {
// This implies pivoting is not required => Construct 1 translation model
await _constructTranslationModelHelper(languagePairs[0]);
// Non-pivoting case requires only 1 translation model
await _constructTranslationModelHelper(_getLanguagePair(from, to));
}
}
/**
 * Translates text from the source language to the target language, going
 * through PIVOT_LANGUAGE when _isPivotingRequired(from, to) says so.
 *
 * @param from - Source language code.
 * @param to - Target language code.
 * @param input - Text to translate; passed as-is to _prepareSourceText.
 * @param translateOptions - Translate options; passed as-is to _prepareResponseOptions.
 * @returns The translated text extracted from the response by _parseTranslatedText.
 * @throws {Error} Whatever the helpers throw (e.g. a required translation model
 *   is not loaded); the vectors constructed so far are still released.
 */
const translate = (from, to, input, translateOptions) => {
  // Declared outside `try` so that `finally` can release whatever got constructed.
  let vectorResponseOptions;
  let vectorSourceText;
  let vectorResponse;
  try {
    // Prepare the arguments (vectorResponseOptions and vectorSourceText (vector<string>)) of Translation API and call it.
    // Result is a vector<Response> where each of its item corresponds to one item of vectorSourceText in the same order.
    vectorResponseOptions = _prepareResponseOptions(translateOptions);
    vectorSourceText = _prepareSourceText(input);

    if (_isPivotingRequired(from, to)) {
      // Translate via pivoting: source -> pivot model first, pivot -> target model second
      const translationModelSrcToPivot = _getLoadedTranslationModel(from, PIVOT_LANGUAGE);
      const translationModelPivotToTarget = _getLoadedTranslationModel(PIVOT_LANGUAGE, to);
      vectorResponse = translationService.translateViaPivoting(translationModelSrcToPivot,
                                                               translationModelPivotToTarget,
                                                               vectorSourceText,
                                                               vectorResponseOptions);
    }
    else {
      // Translate without pivoting
      const translationModel = _getLoadedTranslationModel(from, to);
      vectorResponse = translationService.translate(translationModel, vectorSourceText, vectorResponseOptions);
    }

    // Parse all relevant information from vectorResponse before it is released
    const listTranslatedText = _parseTranslatedText(vectorResponse);
    const listSourceText = _parseSourceText(vectorResponse);
    const listTranslatedTextSentences = _parseTranslatedTextSentences(vectorResponse);
    const listSourceTextSentences = _parseSourceTextSentences(vectorResponse);
    const listTranslatedTextSentenceQualityScores = _parseTranslatedTextSentenceQualityScores(vectorResponse);

    log(`Source text: ${listSourceText}`);
    log(`Translated text: ${listTranslatedText}`);
    log(`Translated sentences: ${JSON.stringify(listTranslatedTextSentences)}`);
    log(`Source sentences: ${JSON.stringify(listSourceTextSentences)}`);
    log(`Translated sentence quality scores: ${JSON.stringify(listTranslatedTextSentenceQualityScores)}`);
    return listTranslatedText;
  } finally {
    // Necessary clean up: runs on both the success and the error path.
    // `!= null` also covers variables that were never assigned (undefined).
    if (vectorSourceText != null) vectorSourceText.delete();
    if (vectorResponseOptions != null) vectorResponseOptions.delete();
    if (vectorResponse != null) vectorResponse.delete();
  }
}
// Downloads file from a url and returns the array buffer
@ -174,6 +176,7 @@ const _prepareAlignedMemoryFromBuffer = async (buffer, alignmentSize) => {
}
const _constructTranslationModelHelper = async (languagePair) => {
log(`Constructing translation model ${languagePair}`);
/*Set the Model Configuration as YAML formatted string.
For available configuration options, please check: https://marian-nmt.github.io/docs/cmd/marian-decoder/
@ -237,35 +240,20 @@ alignment: soft
languagePairToTranslationModels.set(languagePair, translationModel);
}
/**
 * Tells whether a translation between the two languages has to go through
 * PIVOT_LANGUAGE.
 *
 * @param from - Source language code.
 * @param to - Target language code.
 * @returns {boolean} true only when neither language is PIVOT_LANGUAGE.
 */
const _isPivotingRequired = (from, to) => {
  return (from !== PIVOT_LANGUAGE) && (to !== PIVOT_LANGUAGE);
}
/**
 * Lists the language pair(s) needed to translate from srcLang to tgtLang.
 *
 * @param srcLang - Source language code.
 * @param tgtLang - Target language code.
 * @returns {string[]} Two entries when _isPivotingRequired(srcLang, tgtLang)
 *   is true, otherwise a single entry. Each entry is the two language codes
 *   concatenated.
 */
const _getLanguagePairs = (srcLang, tgtLang) => {
  if (_isPivotingRequired(srcLang, tgtLang)) {
    // Do not change the order: source -> pivot first, pivot -> target second
    return [`${srcLang}${PIVOT_LANGUAGE}`, `${PIVOT_LANGUAGE}${tgtLang}`];
  }
  return [`${srcLang}${tgtLang}`];
}
/**
 * Builds the language-pair string by concatenating the source and target
 * language codes, in that order.
 *
 * @param srcLang - Source language code.
 * @param tgtLang - Target language code.
 * @returns {string} srcLang immediately followed by tgtLang.
 */
const _getLanguagePair = (srcLang, tgtLang) => `${srcLang}${tgtLang}`;
/**
 * Collects the loaded translation models for the language pair(s) returned by
 * _getLanguagePairs(srcLang, tgtLang). Pairs without a loaded model are
 * skipped, so the result may be shorter than the list of pairs.
 *
 * @param srcLang - Source language code.
 * @param tgtLang - Target language code.
 * @returns {Array} Loaded models, in the same order as the language pairs.
 */
const _getLoadedTranslationModels = (srcLang, tgtLang) => {
  const languagePairs = _getLanguagePairs(srcLang, tgtLang);
  const loadedTranslationModels = [];
  for (const langPair of languagePairs) {
    if (languagePairToTranslationModels.has(langPair)) {
      loadedTranslationModels.push(languagePairToTranslationModels.get(langPair));
    }
  }
  return loadedTranslationModels;
}

/**
 * Looks up the loaded translation model for a single language pair.
 *
 * @param srcLang - Source language code.
 * @param tgtLang - Target language code.
 * @returns The model stored in languagePairToTranslationModels for the pair.
 * @throws {Error} When no model is stored for the pair.
 */
const _getLoadedTranslationModel = (srcLang, tgtLang) => {
  const languagePair = _getLanguagePair(srcLang, tgtLang);
  if (!languagePairToTranslationModels.has(languagePair)) {
    throw new Error(`Translation model '${languagePair}' not loaded`);
  }
  return languagePairToTranslationModels.get(languagePair);
}
const _parseTranslatedText = (vectorResponse) => {
@ -343,10 +331,18 @@ const _parseTranslatedTextSentenceQualityScores = (vectorResponse) => {
}
/**
 * Builds the VectorResponseOptions argument of the Translation API, adding one
 * entry per item of translateOptions. Alignment is always requested.
 *
 * The returned vector is not released here; the caller is expected to call
 * delete() on it once done.
 *
 * @param translateOptions - Collection (with forEach) of objects carrying
 *   "isQualityScores" and "isHtml".
 * @returns The populated Module.VectorResponseOptions instance.
 * @throws {Error} When no option was provided. The vector is released before
 *   any error leaves this function.
 */
const _prepareResponseOptions = (translateOptions) => {
  const vectorResponseOptions = new Module.VectorResponseOptions();
  try {
    translateOptions.forEach(translateOption => {
      vectorResponseOptions.push_back({
        qualityScores: translateOption["isQualityScores"],
        alignment: true,
        html: translateOption["isHtml"]
      });
    });
    if (vectorResponseOptions.size() === 0) {
      throw new Error(`No Translation Options provided`);
    }
  } catch (error) {
    // Nobody else holds a reference yet, so release the vector here to avoid a leak
    vectorResponseOptions.delete();
    throw error;
  }
  return vectorResponseOptions;
}
@ -359,6 +355,10 @@ const _prepareSourceText = (input) => {
}
vectorSourceText.push_back(paragraph.trim())
})
if (vectorSourceText.size() == 0) {
vectorSourceText.delete();
throw Error(`No text provided to translate`);
}
return vectorSourceText;
}