JS bindings for loading model and shortlist files as bytes (#117)

* Bindings to load model and shortlist files as bytes
* Modified wasm test page for byte based loading of files
* Updates wasm README for byte loading based usage of TranslationModel
This commit is contained in:
abhi-agg 2021-04-29 12:04:04 +02:00 committed by GitHub
parent e5ec5bdd33
commit de0abfd795
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 104 additions and 30 deletions

View File

@ -1,9 +1,19 @@
## Using Bergamot Translator in JavaScript
The example file `bergamot.html` in the folder `test_page` demonstrates how to use the bergamot translator in JavaScript via a `<script>` tag.
Please note that everything below assumes that the [bergamot project specific model files](https://github.com/mozilla-applied-ml/bergamot-models) were packaged in wasm binary (using the compile instructions given in the top level README).
### <a name="Pre-requisite"></a> Pre-requisite: Download files required for translation
### Using JS APIs
Please note that the [Using JS APIs](#Using-JS-APIs) and [Demo](#Demo) sections below assume that the [bergamot project specific model files](https://github.com/mozilla-applied-ml/bergamot-models) have already been downloaded and are present in the `test_page` folder. If they are not, use the following instructions to download them:
```bash
cd test_page
mkdir models
git clone --depth 1 --branch main --single-branch https://github.com/mozilla-applied-ml/bergamot-models
cp -rf bergamot-models/prod/* models
gunzip models/*/*
```
### <a name="Using-JS-APIs"></a> Using JS APIs
```js
// The model configuration as YAML formatted string. For available configuration options, please check: https://marian-nmt.github.io/docs/cmd/marian-decoder/
@ -34,13 +44,16 @@ request.delete();
input.delete();
```
### Demo (see everything in action)
### <a name="Demo"></a> Demo (see everything in action)
* Make sure that you have followed the [Pre-requisite](#Pre-requisite) instructions before moving forward.
* Start the test webserver (ensure you have the latest nodejs installed)
```bash
cd test_page
bash start_server.sh
```
* Open any of the browsers below
* Firefox Nightly +87: make sure the following prefs are on (about:config)
```

View File

@ -10,10 +10,27 @@
using namespace emscripten;
// Binding code
// Wraps the contents of an AlignedMemory in an emscripten typed memory view so
// that JavaScript can read and write the buffer as a byte array. The view is
// built from the buffer's size() and its storage reinterpreted as char.
val getByteArrayView(marian::bergamot::AlignedMemory& alignedMemory) {
  const auto byteCount = alignedMemory.size();
  auto* firstByte = alignedMemory.as<char>();
  return val(typed_memory_view(byteCount, firstByte));
}
// Registers AlignedMemory with embind under the JavaScript name "AlignedMemory":
// a two-argument (size_t, size_t) constructor, size(), and getByteArrayView()
// for obtaining a byte view of the buffer.
EMSCRIPTEN_BINDINGS(aligned_memory) {
  using AlignedMemory = marian::bergamot::AlignedMemory;

  class_<AlignedMemory>("AlignedMemory")
      .constructor<std::size_t, std::size_t>()
      .function("size", &AlignedMemory::size)
      .function("getByteArrayView", &getByteArrayView);
}
// Builds a heap-allocated TranslationModel from a config string plus model and
// shortlist bytes held in AlignedMemory objects. Both buffers are handed to the
// TranslationModel constructor as rvalues (std::move), so callers should treat
// the AlignedMemory objects they passed in as consumed.
// NOTE(review): neither pointer is checked for null before it is dereferenced.
TranslationModel* TranslationModelFactory(const std::string &config,
                                          marian::bergamot::AlignedMemory* modelMemory,
                                          marian::bergamot::AlignedMemory* shortlistMemory) {
  marian::bergamot::AlignedMemory& modelBytes = *modelMemory;
  marian::bergamot::AlignedMemory& shortlistBytes = *shortlistMemory;
  return new TranslationModel(config, std::move(modelBytes), std::move(shortlistBytes));
}
EMSCRIPTEN_BINDINGS(translation_model) {
class_<TranslationModel>("TranslationModel")
.constructor<std::string>()
.constructor(&TranslationModelFactory, allow_raw_pointers())
.function("translate", &TranslationModel::translate)
.function("isAlignmentSupported", &TranslationModel::isAlignmentSupported)
;

View File

@ -2,7 +2,7 @@
<html>
<head>
<link rel="icon" href="data:,">
<meta http-equiv="Content-Type" content="text/html;charset=ISO-8859-1">
<meta http-equiv="Content-Type" content="text/html;charset=UTF-8">
</head>
<style>
body, html, div {
@ -61,9 +61,27 @@ En consecuencia, durante el año 2011 se introdujeron 180 proyectos de ley que r
</div>
<script>
/**
 * Downloads the resource at `url` and resolves with its body as an ArrayBuffer.
 *
 * @param {string} url - Location of the file to fetch.
 * @returns {Promise<ArrayBuffer>} The raw bytes of the response body.
 * @throws {Error} If the response status is not OK. The message carries the
 *   HTTP status and the URL, so that when several files are downloaded in
 *   parallel the failing one can be identified.
 */
const downloadAsArrayBuffer = async (url) => {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`HTTP ${response.status} - ${response.statusText} (${url})`);
  }
  return response.arrayBuffer();
};
var model, request, input = undefined;
const loadModel = (from, to) => {
/**
 * Creates a Module.AlignedMemory sized to hold `buffer` and copies the
 * buffer's bytes into it through the memory's byte-array view.
 *
 * @param {ArrayBuffer} buffer - Bytes to copy.
 * @param {number} alignmentSize - Alignment passed to the AlignedMemory constructor.
 * @returns {object} The populated Module.AlignedMemory instance.
 */
function constructAlignedMemoryFromBuffer(buffer, alignmentSize) {
  const sourceBytes = new Int8Array(buffer);
  const byteCount = sourceBytes.byteLength;
  console.debug("byteArray size: ", byteCount);

  const memory = new Module.AlignedMemory(byteCount, alignmentSize);
  // The view aliases the memory's own storage, so set() fills the AlignedMemory.
  memory.getByteArrayView().set(sourceBytes);
  return memory;
}
var translationModel, request, input = undefined;
const constructTranslationModel = async (from, to) => {
const languagePair = `${from}${to}`;
@ -72,7 +90,7 @@ En consecuencia, durante el año 2011 se introdujeron 180 proyectos de ley que r
// Set the Model Configuration as YAML formatted string.
// For available configuration options, please check: https://marian-nmt.github.io/docs/cmd/marian-decoder/
const modelConfig = `models:
/*const modelConfig = `models:
- /${languagePair}/model.${languagePair}.intgemm.alphas.bin
vocabs:
- /${vocabLanguagePair}/vocab.${vocabLanguagePair}.spm
@ -93,22 +111,53 @@ shortlist:
- 50
- 50
`;
/*
This config is not valid anymore in new APIs
mini-batch: 32
maxi-batch: 100
maxi-batch-sort: src
*/
const modelConfigWithoutModelAndShortList = `vocabs:
- /${vocabLanguagePair}/vocab.${vocabLanguagePair}.spm
- /${vocabLanguagePair}/vocab.${vocabLanguagePair}.spm
beam-size: 1
normalize: 1.0
word-penalty: 0
max-length-break: 128
mini-batch-words: 1024
workspace: 128
max-length-factor: 2.0
skip-cost: true
cpu-threads: 0
quiet: true
quiet-translation: true
`;
// TODO: Use in model config when wormhole is enabled:
// gemm-precision: int8shift
// TODO: Use in model config when loading of binary models is supported and we use model.intgemm.alphas.bin:
// gemm-precision: int8shiftAlphaAll
console.debug("modelConfig: ", modelConfig);
const modelFile = `${languagePair}/model.${languagePair}.intgemm.alphas.bin`;
console.debug("modelFile: ", modelFile);
const shortlistFile = `${languagePair}/lex.${languagePair}.s2t.bin`;
console.debug("shortlistFile: ", shortlistFile);
// Instantiate the TranslationModel
if (model) model.delete();
model = new Module.TranslationModel(modelConfig);
try {
// Download the files as buffers from the given urls
let start = Date.now();
const downloadedBuffers = await Promise.all([downloadAsArrayBuffer(modelFile), downloadAsArrayBuffer(shortlistFile)]);
const modelBuffer = downloadedBuffers[0];
const shortListBuffer = downloadedBuffers[1];
log(`${languagePair} file download took ${(Date.now() - start) / 1000} secs`);
// Construct AlignedMemory objects with downloaded buffers
var alignedModelMemory = constructAlignedMemoryFromBuffer(modelBuffer, 256);
var alignedShortlistMemory = constructAlignedMemoryFromBuffer(shortListBuffer, 64);
// Instantiate the TranslationModel
if (translationModel) translationModel.delete();
console.debug("Creating TranslationModel with config:", modelConfigWithoutModelAndShortList);
translationModel = new Module.TranslationModel(modelConfigWithoutModelAndShortList, alignedModelMemory, alignedShortlistMemory);
} catch (error) {
console.error(error);
}
}
const translate = (paragraphs) => {
@ -127,16 +176,9 @@ maxi-batch-sort: src
})
// Access input (just for debugging)
console.log('Input size=', input.size());
/*
for (let i = 0; i < input.size(); i++) {
console.log(' val:' + input.get(i));
}
*/
// Translate the input; the result is a vector<TranslationResult>
let result = model.translate(input, request);
// Access original and translated text from each entry of vector<TranslationResult>
//console.log('Result size=', result.size(), ' - TimeDiff - ', (Date.now() - start)/1000);
let result = translationModel.translate(input, request);
const translatedParagraphs = [];
for (let i = 0; i < result.size(); i++) {
translatedParagraphs.push(result.get(i).getTranslatedText());
@ -147,14 +189,16 @@ maxi-batch-sort: src
return translatedParagraphs;
}
document.querySelector("#load").addEventListener("click", () => {
document.querySelector("#load").addEventListener("click", async() => {
document.querySelector("#load").disabled = true;
const lang = document.querySelector('input[name="modellang"]:checked').value;
const from = lang.substring(0, 2);
const to = lang.substring(2, 4);
let start = Date.now();
loadModel(from, to)
log(`model ${from}${to} loaded in ${(Date.now() - start) / 1000} secs`);
//log('Model Alignment:', model.isAlignmentSupported());
await constructTranslationModel(from, to);
log(`translation model ${from}${to} construction took ${(Date.now() - start) / 1000} secs`);
document.querySelector("#load").disabled = false;
//log('Model Alignment:', translationModel.isAlignmentSupported());
});
const translateCall = () => {