From b00116cb9485689fd43521ad3cbcbe939f650f84 Mon Sep 17 00:00:00 2001 From: Abhishek Aggarwal <66322306+abhi-agg@users.noreply.github.com> Date: Tue, 15 Jun 2021 16:02:14 +0200 Subject: [PATCH] Refactor wasm bindings to use consistent interface names as in native (#195) * Refactored wasm bindings code - Replaced TranslationModel, TranslationRequest and TranslationResult with Service, ResponseOptions and Response - Corresponding documentation changes - Names of the bindings files changed - Moved Vector definition in Response specific bindings file --- src/translator/service.h | 2 +- wasm/CMakeLists.txt | 6 ++-- wasm/README.md | 18 ++++++------ wasm/bindings/TranslationRequestBindings.cpp | 15 ---------- wasm/bindings/TranslationResultBindings.cpp | 22 -------------- wasm/bindings/response_bindings.cpp | 24 +++++++++++++++ wasm/bindings/response_options_bindings.cpp | 15 ++++++++++ ...ModelBindings.cpp => service_bindings.cpp} | 29 +++++++------------ wasm/test_page/bergamot.html | 28 +++++++++--------- 9 files changed, 77 insertions(+), 82 deletions(-) delete mode 100644 wasm/bindings/TranslationRequestBindings.cpp delete mode 100644 wasm/bindings/TranslationResultBindings.cpp create mode 100644 wasm/bindings/response_bindings.cpp create mode 100644 wasm/bindings/response_options_bindings.cpp rename wasm/bindings/{TranslationModelBindings.cpp => service_bindings.cpp} (69%) diff --git a/src/translator/service.h b/src/translator/service.h index 26ea831..6c7ea9a 100644 --- a/src/translator/service.h +++ b/src/translator/service.h @@ -105,7 +105,7 @@ class Service { /// recommended to work with futures and translate() API. /// /// @param [in] source: rvalue reference of the string to be translated - /// @param [in] translationRequest: ResponseOptions indicating whether or not + /// @param [in] responseOptions: ResponseOptions indicating whether or not /// to include some member in the Response, also specify any additional /// configurable parameters. std::vector translateMultiple(std::vector &&source, ResponseOptions responseOptions); diff --git a/wasm/CMakeLists.txt b/wasm/CMakeLists.txt index 602dc5d..1580def 100644 --- a/wasm/CMakeLists.txt +++ b/wasm/CMakeLists.txt @@ -1,7 +1,7 @@ add_executable(bergamot-translator-worker - bindings/TranslationModelBindings.cpp - bindings/TranslationRequestBindings.cpp - bindings/TranslationResultBindings.cpp + bindings/service_bindings.cpp + bindings/response_options_bindings.cpp + bindings/response_bindings.cpp ) # Generate version file that can be included in the wasm artifacts diff --git a/wasm/README.md b/wasm/README.md index 01fea18..728b0a3 100644 --- a/wasm/README.md +++ b/wasm/README.md @@ -63,27 +63,27 @@ var alignedShortlistMemory = constructAlignedMemoryFromBuffer(shortListBuffer, 6 var alignedVocabsMemoryList = new Module.AlignedMemoryList; downloadedVocabBuffers.forEach(item => alignedVocabsMemoryList.push_back(constructAlignedMemoryFromBuffer(item, 64))); -// Instantiate the TranslationModel -const model = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList); +// Instantiate the Translation Service +const translationService = new Module.Service(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList); -// Instantiate the arguments of translate() API i.e. TranslationRequest and input (vector) -const request = new Module.TranslationRequest(); +// Instantiate the arguments of translate() API i.e. ResponseOptions and input (vector) +const responseOptions = new Module.ResponseOptions(); const input = new Module.VectorString; // Initialize the input input.push_back("Hola"); input.push_back("Mundo"); -// translate the input; the result is a vector -const result = model.translate(input, request); +// translate the input; the result is a vector +const result = translationService.translate(input, responseOptions); -// Print original and translated text from each entry of vector +// Print original and translated text from each entry of vector for (let i = 0; i < result.size(); i++) { console.log(' original=' + result.get(i).getOriginalText() + ', translation=' + result.get(i).getTranslatedText()); } // Don't forget to clean up the instances -model.delete(); -request.delete(); +translationService.delete(); +responseOptions.delete(); input.delete(); ``` diff --git a/wasm/bindings/TranslationRequestBindings.cpp b/wasm/bindings/TranslationRequestBindings.cpp deleted file mode 100644 index 42ac6c6..0000000 --- a/wasm/bindings/TranslationRequestBindings.cpp +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Bindings for TranslationRequest class - * - */ - -#include - -#include "response_options.h" - -typedef marian::bergamot::ResponseOptions TranslationRequest; - -using namespace emscripten; - -// Binding code -EMSCRIPTEN_BINDINGS(translation_request) { class_("TranslationRequest").constructor<>(); } diff --git a/wasm/bindings/TranslationResultBindings.cpp b/wasm/bindings/TranslationResultBindings.cpp deleted file mode 100644 index f02bef9..0000000 --- a/wasm/bindings/TranslationResultBindings.cpp +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Bindings for TranslationResult class - * - */ - -#include - -#include - -#include "response.h" - -typedef marian::bergamot::Response TranslationResult; - -using namespace emscripten; - -// Binding code -EMSCRIPTEN_BINDINGS(translation_result) { - class_("TranslationResult") - .constructor<>() - .function("getOriginalText", &TranslationResult::getOriginalText) - .function("getTranslatedText", &TranslationResult::getTranslatedText); -} diff --git a/wasm/bindings/response_bindings.cpp b/wasm/bindings/response_bindings.cpp new file mode 100644 index 0000000..5619119 --- /dev/null +++ b/wasm/bindings/response_bindings.cpp @@ -0,0 +1,24 @@ +/* + * Bindings for Response class + * + */ + +#include + +#include + +#include "response.h" + +typedef marian::bergamot::Response Response; + +using namespace emscripten; + +// Binding code +EMSCRIPTEN_BINDINGS(response) { + class_("Response") + .constructor<>() + .function("getOriginalText", &Response::getOriginalText) + .function("getTranslatedText", &Response::getTranslatedText); + + register_vector("VectorResponse"); +} diff --git a/wasm/bindings/response_options_bindings.cpp b/wasm/bindings/response_options_bindings.cpp new file mode 100644 index 0000000..e2bf8e1 --- /dev/null +++ b/wasm/bindings/response_options_bindings.cpp @@ -0,0 +1,15 @@ +/* + * Bindings for ResponseOptions class + * + */ + +#include + +#include "response_options.h" + +typedef marian::bergamot::ResponseOptions ResponseOptions; + +using namespace emscripten; + +// Binding code +EMSCRIPTEN_BINDINGS(response_options) { class_("ResponseOptions").constructor<>(); } diff --git a/wasm/bindings/TranslationModelBindings.cpp b/wasm/bindings/service_bindings.cpp similarity index 69% rename from wasm/bindings/TranslationModelBindings.cpp rename to wasm/bindings/service_bindings.cpp index 64203a1..416a318 100644 --- a/wasm/bindings/TranslationModelBindings.cpp +++ b/wasm/bindings/service_bindings.cpp @@ -1,18 +1,14 @@ /* - * TranslationModelBindings.cpp - * - * Bindings for TranslationModel class + * Bindings for Service class */ #include -#include "response.h" #include "service.h" using namespace emscripten; -typedef marian::bergamot::Service TranslationModel; -typedef marian::bergamot::Response TranslationResult; +typedef marian::bergamot::Service Service; typedef marian::bergamot::AlignedMemory AlignedMemory; val getByteArrayView(AlignedMemory& alignedMemory) { @@ -29,7 +25,7 @@ EMSCRIPTEN_BINDINGS(aligned_memory) { } // When source and target vocab files are same, only one memory object is passed from JS to -// avoid allocating memory twice for the same file. However, the constructor of the TranslationModel +// avoid allocating memory twice for the same file. However, the constructor of the Service // class still expects 2 entries in this case, where each entry has the shared ownership of the // same AlignedMemory object. This function prepares these smart pointer based AlignedMemory objects // for unique AlignedMemory objects passed from JS. @@ -56,21 +52,18 @@ marian::bergamot::MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, A return memoryBundle; } -TranslationModel* TranslationModelFactory(const std::string& config, AlignedMemory* modelMemory, - AlignedMemory* shortlistMemory, - std::vector uniqueVocabsMemories) { - return new TranslationModel(config, - std::move(prepareMemoryBundle(modelMemory, shortlistMemory, uniqueVocabsMemories))); +Service* ServiceFactory(const std::string& config, AlignedMemory* modelMemory, AlignedMemory* shortlistMemory, + std::vector uniqueVocabsMemories) { + return new Service(config, std::move(prepareMemoryBundle(modelMemory, shortlistMemory, uniqueVocabsMemories))); } -EMSCRIPTEN_BINDINGS(translation_model) { - class_("TranslationModel") - .constructor(&TranslationModelFactory, allow_raw_pointers()) - .function("translate", &TranslationModel::translateMultiple) - .function("isAlignmentSupported", &TranslationModel::isAlignmentSupported); +EMSCRIPTEN_BINDINGS(translation_service) { + class_("Service") + .constructor(&ServiceFactory, allow_raw_pointers()) + .function("translate", &Service::translateMultiple) + .function("isAlignmentSupported", &Service::isAlignmentSupported); // ^ We redirect Service::translateMultiple to WASMBound::translate instead. Sane API is // translate. If and when async comes, we can be done with this inconsistency. register_vector("VectorString"); - register_vector("VectorTranslationResult"); } diff --git a/wasm/test_page/bergamot.html b/wasm/test_page/bergamot.html index d150af6..c69c950 100644 --- a/wasm/test_page/bergamot.html +++ b/wasm/test_page/bergamot.html @@ -80,8 +80,8 @@ En consecuencia, durante el año 2011 se introdujeron 180 proyectos de ley que r return alignedMemory; } - var translationModel, request, input = undefined; - const constructTranslationModel = async (from, to) => { + var translationService, responseOptions, input = undefined; + const constructTranslationService = async (from, to) => { const languagePair = `${from}${to}`; @@ -162,10 +162,10 @@ gemm-precision: int8shift var alignedVocabsMemoryList = new Module.AlignedMemoryList; downloadedVocabBuffers.forEach(item => alignedVocabsMemoryList.push_back(constructAlignedMemoryFromBuffer(item, 64))); - // Instantiate the TranslationModel - if (translationModel) translationModel.delete(); - console.debug("Creating TranslationModel with config:", modelConfig); - translationModel = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList); + // Instantiate the Translation Service + if (translationService) translationService.delete(); + console.debug("Creating Translation Service with config:", modelConfig); + translationService = new Module.Service(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList); } catch (error) { log(error); } @@ -173,8 +173,8 @@ gemm-precision: int8shift const translate = (paragraphs) => { - // Instantiate the arguments of translate() API i.e. TranslationRequest and input (vector) - var request = new Module.TranslationRequest(); + // Instantiate the arguments of translate() API i.e. ResponseOptions and input (vector) + var responseOptions = new Module.ResponseOptions(); let input = new Module.VectorString; // Initialize the input @@ -188,14 +188,14 @@ gemm-precision: int8shift // Access input (just for debugging) console.log('Input size=', input.size()); - // Translate the input; the result is a vector - let result = translationModel.translate(input, request); + // Translate the input; the result is a vector + let result = translationService.translate(input, responseOptions); const translatedParagraphs = []; for (let i = 0; i < result.size(); i++) { translatedParagraphs.push(result.get(i).getTranslatedText()); } console.log({ translatedParagraphs }); - request.delete(); + responseOptions.delete(); input.delete(); return translatedParagraphs; } @@ -206,10 +206,10 @@ gemm-precision: int8shift const from = lang.substring(0, 2); const to = lang.substring(2, 4); let start = Date.now(); - await constructTranslationModel(from, to); - log(`translation model ${from}${to} construction took ${(Date.now() - start) / 1000} secs`); + await constructTranslationService(from, to); + log(`translation service ${from}${to} construction took ${(Date.now() - start) / 1000} secs`); document.querySelector("#load").disabled = false; - //log('Model Alignment:', translationModel.isAlignmentSupported()); + //log('Model Alignment:', translationService.isAlignmentSupported()); }); const translateCall = () => {