diff --git a/wasm/test_page/bergamot.html b/wasm/test_page/bergamot.html index 95ae325..d150af6 100644 --- a/wasm/test_page/bergamot.html +++ b/wasm/test_page/bergamot.html @@ -113,10 +113,7 @@ shortlist: `; */ -const modelConfigWithoutModelAndShortList = `vocabs: - - /${languagePair}/vocab.${vocabLanguagePair}.spm - - /${languagePair}/vocab.${vocabLanguagePair}.spm -beam-size: 1 +const modelConfig = `beam-size: 1 normalize: 1.0 word-penalty: 0 max-length-break: 128 @@ -136,9 +133,15 @@ gemm-precision: int8shift // gemm-precision: int8shiftAlphaAll const modelFile = `models/${languagePair}/model.${languagePair}.intgemm.alphas.bin`; - console.debug("modelFile: ", modelFile); const shortlistFile = `models/${languagePair}/lex.50.50.${languagePair}.s2t.bin`; + const vocabFiles = [`models/${languagePair}/vocab.${vocabLanguagePair}.spm`, + `models/${languagePair}/vocab.${vocabLanguagePair}.spm`]; + + const uniqueVocabFiles = new Set(vocabFiles); + console.debug("modelFile: ", modelFile); console.debug("shortlistFile: ", shortlistFile); + console.debug("No. of unique vocabs: ", uniqueVocabFiles.size); + uniqueVocabFiles.forEach(item => console.debug("unique vocabFile: ", item)); try { // Download the files as buffers from the given urls @@ -146,16 +149,23 @@ gemm-precision: int8shift const downloadedBuffers = await Promise.all([downloadAsArrayBuffer(modelFile), downloadAsArrayBuffer(shortlistFile)]); const modelBuffer = downloadedBuffers[0]; const shortListBuffer = downloadedBuffers[1]; + + const downloadedVocabBuffers = []; + for (let item of uniqueVocabFiles.values()) { + downloadedVocabBuffers.push(await downloadAsArrayBuffer(item)); + } log(`${languagePair} file download took ${(Date.now() - start) / 1000} secs`); // Construct AlignedMemory objects with downloaded buffers var alignedModelMemory = constructAlignedMemoryFromBuffer(modelBuffer, 256); var alignedShortlistMemory = constructAlignedMemoryFromBuffer(shortListBuffer, 64); + var alignedVocabsMemoryList = new Module.AlignedMemoryList; + downloadedVocabBuffers.forEach(item => alignedVocabsMemoryList.push_back(constructAlignedMemoryFromBuffer(item, 64))); // Instantiate the TranslationModel if (translationModel) translationModel.delete(); - console.debug("Creating TranslationModel with config:", modelConfigWithoutModelAndShortList); - translationModel = new Module.TranslationModel(modelConfigWithoutModelAndShortList, alignedModelMemory, alignedShortlistMemory); + console.debug("Creating TranslationModel with config:", modelConfig); + translationModel = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList); } catch (error) { log(error); }