JS: Using supervised QE models for available language pairs (#378)

* JS: Refactored model loading
 - Passing single vocab memory via JS
* JS: Use supervised QE models when available
* Ran clang format
This commit is contained in:
Abhishek Aggarwal 2022-03-15 15:55:28 +01:00 committed by GitHub
parent 2c0e65c2ec
commit 0a52a6d405
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 43 deletions

View File

@ -45,11 +45,15 @@ std::vector<std::shared_ptr<AlignedMemory>> prepareVocabsSmartMemories(std::vect
}
// Assembles a MemoryBundle from caller-supplied AlignedMemory objects.
//
// modelMemory / shortlistMemory: dereferenced unconditionally and assigned into the
//   bundle's `model` and `shortlist` fields via std::move.
// uniqueVocabsMemories: forwarded to prepareVocabsSmartMemories(); the result becomes
//   the bundle's `vocabs` field.
// qualityEstimatorMemory: optional. When it is nullptr the bundle's
//   `qualityEstimatorMemory` field is not assigned and keeps whatever value the
//   default-constructed MemoryBundle gave it.
//
// Returns the populated MemoryBundle by value.
//
// NOTE(review): this text is a diff view without +/- markers. The parameter line ending
// in `uniqueVocabsMemories) {` appears to be the pre-change signature line; the two lines
// that follow it appear to be its replacement - confirm against the actual file.
MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shortlistMemory,
std::vector<AlignedMemory*> uniqueVocabsMemories) {
std::vector<AlignedMemory*> uniqueVocabsMemories,
AlignedMemory* qualityEstimatorMemory) {
MemoryBundle memoryBundle;
memoryBundle.model = std::move(*modelMemory);
memoryBundle.shortlist = std::move(*shortlistMemory);
memoryBundle.vocabs = std::move(prepareVocabsSmartMemories(uniqueVocabsMemories));
// The quality-estimator memory is optional; only take it when one was supplied.
if (qualityEstimatorMemory != nullptr) {
memoryBundle.qualityEstimatorMemory = std::move(*qualityEstimatorMemory);
}
return memoryBundle;
}
@ -57,9 +61,9 @@ MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shor
// This allows only shared_ptrs to be operational in JavaScript, according to emscripten.
// https://emscripten.org/docs/porting/connecting_cpp_and_javascript/embind.html#smart-pointers
// Factory that returns a shared_ptr<TranslationModel>.
//
// The model, shortlist, vocab and quality-estimator memories are first packed into a
// MemoryBundle by prepareMemoryBundle(); the bundle is then moved into the
// TranslationModel constructor together with `config`.
// `qualityEstimator` is passed straight through to prepareMemoryBundle(), which checks it
// against nullptr before use.
//
// NOTE(review): this text is a diff view without +/- markers. The three-argument
// prepareMemoryBundle(...) call and the parameter lines preceding it appear to be the
// pre-change version; the four-argument form appears to be the current one - confirm
// against the actual file.
std::shared_ptr<TranslationModel> TranslationModelFactory(const std::string& config, AlignedMemory* model,
AlignedMemory* shortlist,
std::vector<AlignedMemory*> vocabs) {
MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs);
AlignedMemory* shortlist, std::vector<AlignedMemory*> vocabs,
AlignedMemory* qualityEstimator) {
MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs, qualityEstimator);
return std::make_shared<TranslationModel>(config, std::move(memoryBundle));
}

View File

@ -12,6 +12,14 @@ const MODEL_REGISTRY = "../models/registry.json";
const MODEL_ROOT_URL = "../models/";
const PIVOT_LANGUAGE = 'en';
// Information corresponding to each file type.
//  - "type": key looked up in modelRegistry[languagePair] to find the file's entry.
//  - "alignment": alignment size handed to Module.AlignedMemory when the downloaded
//    bytes are wrapped.
// Order matters: the model-construction code reads the prepared memories by position
// (0 = model, 1 = lex/shortlist, 2 = vocab, 3 = qualityModel) and treats a result list
// as long as this table as "a quality model is available".
const fileInfo = [
{"type": "model", "alignment": 256},
{"type": "lex", "alignment": 64},
{"type": "vocab", "alignment": 64},
{"type": "qualityModel", "alignment": 64}
];
const encoder = new TextEncoder(); // string to utf-8 converter
const decoder = new TextDecoder(); // utf-8 to string converter
@ -169,12 +177,17 @@ const _downloadAsArrayBuffer = async(url) => {
// Constructs and initializes the AlignedMemory from the array buffer and alignment size.
//  - buffer: source bytes; viewed through an Int8Array without copying at that step.
//  - alignmentSize: passed as the second argument to Module.AlignedMemory.
// Returns the new Module.AlignedMemory after the bytes have been copied into its
// byte-array view. The function is declared async but contains no await, so callers
// simply receive the result wrapped in a promise.
const _prepareAlignedMemoryFromBuffer = async (buffer, alignmentSize) => {
var byteArray = new Int8Array(buffer);
log(`Constructing Aligned memory. Size: ${byteArray.byteLength} bytes, Alignment: ${alignmentSize}`);
// Allocate an AlignedMemory sized to the downloaded byte count.
var alignedMemory = new Module.AlignedMemory(byteArray.byteLength, alignmentSize);
log(`Aligned memory construction done`);
// Copy the downloaded bytes into the view exposed by the AlignedMemory.
const alignedByteArrayView = alignedMemory.getByteArrayView();
alignedByteArrayView.set(byteArray);
log(`Aligned memory initialized`);
return alignedMemory;
}
// Downloads the registry file matching `file.type` for `languagePair` and wraps the
// bytes in an AlignedMemory using `file.alignment`. Resolves to that AlignedMemory.
async function prepareAlignedMemory(file, languagePair) {
  const { type, alignment } = file;
  const registryEntry = modelRegistry[languagePair][type];
  const url = `${MODEL_ROOT_URL}/${languagePair}/${registryEntry.name}`;
  const downloaded = await _downloadAsArrayBuffer(url);
  const memory = await _prepareAlignedMemoryFromBuffer(downloaded, alignment);
  log(`"${type}" aligned memory prepared. Size:${memory.size()} bytes, alignment:${alignment}`);
  return memory;
}
@ -201,45 +214,26 @@ gemm-precision: int8shiftAlphaAll
alignment: soft
`;
const modelFile = `${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["model"].name}`;
const shortlistFile = `${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["lex"].name}`;
const vocabFiles = [`${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["vocab"].name}`,
`${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["vocab"].name}`];
const promises = [];
fileInfo.filter(file => modelRegistry[languagePair].hasOwnProperty(file.type))
.map((file) => {
promises.push(prepareAlignedMemory(file, languagePair));
});
const uniqueVocabFiles = new Set(vocabFiles);
log(`modelFile: ${modelFile}\nshortlistFile: ${shortlistFile}\nNo. of unique vocabs: ${uniqueVocabFiles.size}`);
uniqueVocabFiles.forEach(item => log(`unique vocabFile: ${item}`));
// Download the files as buffers from the given urls
let start = Date.now();
const downloadedBuffers = await Promise.all([_downloadAsArrayBuffer(modelFile), _downloadAsArrayBuffer(shortlistFile)]);
const modelBuffer = downloadedBuffers[0];
const shortListBuffer = downloadedBuffers[1];
const downloadedVocabBuffers = [];
for (let item of uniqueVocabFiles.values()) {
downloadedVocabBuffers.push(await _downloadAsArrayBuffer(item));
}
log(`Total Download time for all files of '${languagePair}': ${(Date.now() - start) / 1000} secs`);
// Construct AlignedMemory objects with downloaded buffers
let constructedAlignedMemories = await Promise.all([_prepareAlignedMemoryFromBuffer(modelBuffer, 256),
_prepareAlignedMemoryFromBuffer(shortListBuffer, 64)]);
let alignedModelMemory = constructedAlignedMemories[0];
let alignedShortlistMemory = constructedAlignedMemories[1];
let alignedVocabsMemoryList = new Module.AlignedMemoryList;
for(let item of downloadedVocabBuffers) {
let alignedMemory = await _prepareAlignedMemoryFromBuffer(item, 64);
alignedVocabsMemoryList.push_back(alignedMemory);
}
for (let vocabs=0; vocabs < alignedVocabsMemoryList.size(); vocabs++) {
log(`Aligned vocab memory${vocabs+1} size: ${alignedVocabsMemoryList.get(vocabs).size()}`);
}
log(`Aligned model memory size: ${alignedModelMemory.size()}`);
log(`Aligned shortlist memory size: ${alignedShortlistMemory.size()}`);
const alignedMemories = await Promise.all(promises);
log(`Translation Model config: ${modelConfig}`);
var translationModel = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList);
log(`Aligned memory sizes: Model:${alignedMemories[0].size()} Shortlist:${alignedMemories[1].size()} Vocab:${alignedMemories[2].size()}`);
const alignedVocabMemoryList = new Module.AlignedMemoryList();
alignedVocabMemoryList.push_back(alignedMemories[2]);
let translationModel;
if (alignedMemories.length === fileInfo.length) {
log(`QE:${alignedMemories[3].size()}`);
translationModel = new Module.TranslationModel(modelConfig, alignedMemories[0], alignedMemories[1], alignedVocabMemoryList, alignedMemories[3]);
}
else {
translationModel = new Module.TranslationModel(modelConfig, alignedMemories[0], alignedMemories[1], alignedVocabMemoryList, null);
}
languagePairToTranslationModels.set(languagePair, translationModel);
}