mirror of
https://github.com/browsermt/bergamot-translator.git
synced 2024-10-26 05:43:59 +03:00
JS: Using supervised QE models for available language pairs (#378)
* JS: Refactored model loading - Passing single vocab memory via JS
* JS: Use supervised QE models when available
* Ran clang format
This commit is contained in:
parent
2c0e65c2ec
commit
0a52a6d405
@ -45,11 +45,15 @@ std::vector<std::shared_ptr<AlignedMemory>> prepareVocabsSmartMemories(std::vect
|
||||
}
|
||||
|
||||
/// Assembles a MemoryBundle by taking over the contents of the supplied buffers.
///
/// The model and shortlist buffers are moved out of their pointees, the vocabulary
/// buffers are converted through prepareVocabsSmartMemories(), and the
/// quality-estimator buffer is moved only when a non-null pointer is given;
/// otherwise that field keeps its default-constructed value.
///
/// @param modelMemory            Moved into `model`. Dereferenced unconditionally, so it must not be null.
/// @param shortlistMemory        Moved into `shortlist`. Dereferenced unconditionally, so it must not be null.
/// @param uniqueVocabsMemories   Vocabulary buffers handed to prepareVocabsSmartMemories().
/// @param qualityEstimatorMemory Optional; nullptr means no quality-estimator buffer is supplied.
/// @return The populated bundle.
MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shortlistMemory,
                                 std::vector<AlignedMemory*> uniqueVocabsMemories,
                                 AlignedMemory* qualityEstimatorMemory) {
  MemoryBundle memoryBundle;
  memoryBundle.model = std::move(*modelMemory);
  memoryBundle.shortlist = std::move(*shortlistMemory);
  // The helper returns by value, so the result is already an rvalue; wrapping it
  // in std::move would add nothing.
  memoryBundle.vocabs = prepareVocabsSmartMemories(uniqueVocabsMemories);
  if (qualityEstimatorMemory != nullptr) {
    memoryBundle.qualityEstimatorMemory = std::move(*qualityEstimatorMemory);
  }

  return memoryBundle;
}
|
||||
@ -57,9 +61,9 @@ MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shor
|
||||
// This allows only shared_ptrs to be operational in JavaScript, according to emscripten.
|
||||
// https://emscripten.org/docs/porting/connecting_cpp_and_javascript/embind.html#smart-pointers
|
||||
std::shared_ptr<TranslationModel> TranslationModelFactory(const std::string& config, AlignedMemory* model,
|
||||
AlignedMemory* shortlist,
|
||||
std::vector<AlignedMemory*> vocabs) {
|
||||
MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs);
|
||||
AlignedMemory* shortlist, std::vector<AlignedMemory*> vocabs,
|
||||
AlignedMemory* qualityEstimator) {
|
||||
MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs, qualityEstimator);
|
||||
return std::make_shared<TranslationModel>(config, std::move(memoryBundle));
|
||||
}
|
||||
|
||||
|
@ -12,6 +12,14 @@ const MODEL_REGISTRY = "../models/registry.json";
|
||||
const MODEL_ROOT_URL = "../models/";
|
||||
const PIVOT_LANGUAGE = 'en';
|
||||
|
||||
// Information corresponding to each file type: "type" is the key looked up in
// the model registry entry of a language pair, "alignment" is the alignment
// passed when constructing the AlignedMemory that holds the file.
// The order is significant: the model-loading code indexes the prepared
// memories by position (model, lex, vocab, then the optional qualityModel).
const fileInfo = [
  {"type": "model", "alignment": 256},
  {"type": "lex", "alignment": 64},
  {"type": "vocab", "alignment": 64},
  {"type": "qualityModel", "alignment": 64}
];
|
||||
|
||||
const encoder = new TextEncoder(); // string to utf-8 converter
|
||||
const decoder = new TextDecoder(); // utf-8 to string converter
|
||||
|
||||
@ -169,12 +177,17 @@ const _downloadAsArrayBuffer = async(url) => {
|
||||
// Constructs a Module.AlignedMemory sized to `buffer` with the requested
// alignment, copies the buffer's bytes into it and returns it.
//   buffer        - downloaded file contents (wrapped in an Int8Array here)
//   alignmentSize - alignment passed to the AlignedMemory constructor
const _prepareAlignedMemoryFromBuffer = async (buffer, alignmentSize) => {
  const byteArray = new Int8Array(buffer);
  log(`Constructing Aligned memory. Size: ${byteArray.byteLength} bytes, Alignment: ${alignmentSize}`);
  const alignedMemory = new Module.AlignedMemory(byteArray.byteLength, alignmentSize);
  log(`Aligned memory construction done`);
  // Fill the new memory through its byte view.
  const alignedByteArrayView = alignedMemory.getByteArrayView();
  alignedByteArrayView.set(byteArray);
  log(`Aligned memory initialized`);
  return alignedMemory;
};
|
||||
|
||||
// Downloads one file of a language pair and returns it as an AlignedMemory.
//   file         - entry with `type` (model-registry key) and `alignment`
//   languagePair - key into modelRegistry; also the directory name under MODEL_ROOT_URL
// The registry entry modelRegistry[languagePair][file.type] must exist; callers
// are expected to filter out file types the pair does not provide.
async function prepareAlignedMemory(file, languagePair) {
  // MODEL_ROOT_URL may already end with "/"; drop it so the joined URL does not
  // contain an empty path segment ("models//pair/...").
  const rootUrl = MODEL_ROOT_URL.replace(/\/+$/, "");
  const fileName = `${rootUrl}/${languagePair}/${modelRegistry[languagePair][file.type].name}`;
  const buffer = await _downloadAsArrayBuffer(fileName);
  const alignedMemory = await _prepareAlignedMemoryFromBuffer(buffer, file.alignment);
  log(`"${file.type}" aligned memory prepared. Size:${alignedMemory.size()} bytes, alignment:${file.alignment}`);
  return alignedMemory;
}
|
||||
|
||||
@ -201,45 +214,26 @@ gemm-precision: int8shiftAlphaAll
|
||||
alignment: soft
|
||||
`;
|
||||
|
||||
const modelFile = `${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["model"].name}`;
|
||||
const shortlistFile = `${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["lex"].name}`;
|
||||
const vocabFiles = [`${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["vocab"].name}`,
|
||||
`${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["vocab"].name}`];
|
||||
const promises = [];
|
||||
fileInfo.filter(file => modelRegistry[languagePair].hasOwnProperty(file.type))
|
||||
.map((file) => {
|
||||
promises.push(prepareAlignedMemory(file, languagePair));
|
||||
});
|
||||
|
||||
const uniqueVocabFiles = new Set(vocabFiles);
|
||||
log(`modelFile: ${modelFile}\nshortlistFile: ${shortlistFile}\nNo. of unique vocabs: ${uniqueVocabFiles.size}`);
|
||||
uniqueVocabFiles.forEach(item => log(`unique vocabFile: ${item}`));
|
||||
|
||||
// Download the files as buffers from the given urls
|
||||
let start = Date.now();
|
||||
const downloadedBuffers = await Promise.all([_downloadAsArrayBuffer(modelFile), _downloadAsArrayBuffer(shortlistFile)]);
|
||||
const modelBuffer = downloadedBuffers[0];
|
||||
const shortListBuffer = downloadedBuffers[1];
|
||||
|
||||
const downloadedVocabBuffers = [];
|
||||
for (let item of uniqueVocabFiles.values()) {
|
||||
downloadedVocabBuffers.push(await _downloadAsArrayBuffer(item));
|
||||
}
|
||||
log(`Total Download time for all files of '${languagePair}': ${(Date.now() - start) / 1000} secs`);
|
||||
|
||||
// Construct AlignedMemory objects with downloaded buffers
|
||||
let constructedAlignedMemories = await Promise.all([_prepareAlignedMemoryFromBuffer(modelBuffer, 256),
|
||||
_prepareAlignedMemoryFromBuffer(shortListBuffer, 64)]);
|
||||
let alignedModelMemory = constructedAlignedMemories[0];
|
||||
let alignedShortlistMemory = constructedAlignedMemories[1];
|
||||
let alignedVocabsMemoryList = new Module.AlignedMemoryList;
|
||||
for(let item of downloadedVocabBuffers) {
|
||||
let alignedMemory = await _prepareAlignedMemoryFromBuffer(item, 64);
|
||||
alignedVocabsMemoryList.push_back(alignedMemory);
|
||||
}
|
||||
for (let vocabs=0; vocabs < alignedVocabsMemoryList.size(); vocabs++) {
|
||||
log(`Aligned vocab memory${vocabs+1} size: ${alignedVocabsMemoryList.get(vocabs).size()}`);
|
||||
}
|
||||
log(`Aligned model memory size: ${alignedModelMemory.size()}`);
|
||||
log(`Aligned shortlist memory size: ${alignedShortlistMemory.size()}`);
|
||||
const alignedMemories = await Promise.all(promises);
|
||||
|
||||
log(`Translation Model config: ${modelConfig}`);
|
||||
var translationModel = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList);
|
||||
log(`Aligned memory sizes: Model:${alignedMemories[0].size()} Shortlist:${alignedMemories[1].size()} Vocab:${alignedMemories[2].size()}`);
|
||||
const alignedVocabMemoryList = new Module.AlignedMemoryList();
|
||||
alignedVocabMemoryList.push_back(alignedMemories[2]);
|
||||
let translationModel;
|
||||
if (alignedMemories.length === fileInfo.length) {
|
||||
log(`QE:${alignedMemories[3].size()}`);
|
||||
translationModel = new Module.TranslationModel(modelConfig, alignedMemories[0], alignedMemories[1], alignedVocabMemoryList, alignedMemories[3]);
|
||||
}
|
||||
else {
|
||||
translationModel = new Module.TranslationModel(modelConfig, alignedMemories[0], alignedMemories[1], alignedVocabMemoryList, null);
|
||||
}
|
||||
languagePairToTranslationModels.set(languagePair, translationModel);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user