Refactor wasm bindings to use consistent interface names as in native (#195)

* Refactored wasm bindings code
 - Replaced TranslationModel, TranslationRequest and TranslationResult
    with Service, ResponseOptions and Response
 - Corresponding documentation changes
 - Names of the bindings files changed
 - Moved Vector<Response> definition in Response specific bindings
   file
This commit is contained in:
Abhishek Aggarwal 2021-06-15 16:02:14 +02:00 committed by GitHub
parent 4b014665ba
commit b00116cb94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 77 additions and 82 deletions

View File

@ -105,7 +105,7 @@ class Service {
/// recommended to work with futures and translate() API. /// recommended to work with futures and translate() API.
/// ///
/// @param [in] source: rvalue reference of the string to be translated /// @param [in] source: rvalue reference of the string to be translated
/// @param [in] translationRequest: ResponseOptions indicating whether or not /// @param [in] responseOptions: ResponseOptions indicating whether or not
/// to include some member in the Response, also specify any additional /// to include some member in the Response, also specify any additional
/// configurable parameters. /// configurable parameters.
std::vector<Response> translateMultiple(std::vector<std::string> &&source, ResponseOptions responseOptions); std::vector<Response> translateMultiple(std::vector<std::string> &&source, ResponseOptions responseOptions);

View File

@ -1,7 +1,7 @@
add_executable(bergamot-translator-worker add_executable(bergamot-translator-worker
bindings/TranslationModelBindings.cpp bindings/service_bindings.cpp
bindings/TranslationRequestBindings.cpp bindings/response_options_bindings.cpp
bindings/TranslationResultBindings.cpp bindings/response_bindings.cpp
) )
# Generate version file that can be included in the wasm artifacts # Generate version file that can be included in the wasm artifacts

View File

@ -63,27 +63,27 @@ var alignedShortlistMemory = constructAlignedMemoryFromBuffer(shortListBuffer, 6
var alignedVocabsMemoryList = new Module.AlignedMemoryList; var alignedVocabsMemoryList = new Module.AlignedMemoryList;
downloadedVocabBuffers.forEach(item => alignedVocabsMemoryList.push_back(constructAlignedMemoryFromBuffer(item, 64))); downloadedVocabBuffers.forEach(item => alignedVocabsMemoryList.push_back(constructAlignedMemoryFromBuffer(item, 64)));
// Instantiate the TranslationModel // Instantiate the Translation Service
const model = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList); const translationService = new Module.Service(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList);
// Instantiate the arguments of translate() API i.e. TranslationRequest and input (vector<string>) // Instantiate the arguments of translate() API i.e. ResponseOptions and input (vector<string>)
const request = new Module.TranslationRequest(); const responseOptions = new Module.ResponseOptions();
const input = new Module.VectorString; const input = new Module.VectorString;
// Initialize the input // Initialize the input
input.push_back("Hola"); input.push_back("Mundo"); input.push_back("Hola"); input.push_back("Mundo");
// translate the input; the result is a vector<TranslationResult> // translate the input; the result is a vector<Response>
const result = model.translate(input, request); const result = translationService.translate(input, responseOptions);
// Print original and translated text from each entry of vector<TranslationResult> // Print original and translated text from each entry of vector<Response>
for (let i = 0; i < result.size(); i++) { for (let i = 0; i < result.size(); i++) {
console.log(' original=' + result.get(i).getOriginalText() + ', translation=' + result.get(i).getTranslatedText()); console.log(' original=' + result.get(i).getOriginalText() + ', translation=' + result.get(i).getTranslatedText());
} }
// Don't forget to clean up the instances // Don't forget to clean up the instances
model.delete(); translationService.delete();
request.delete(); responseOptions.delete();
input.delete(); input.delete();
``` ```

View File

@ -1,15 +0,0 @@
/*
* Bindings for TranslationRequest class
*
*/
#include <emscripten/bind.h>
#include "response_options.h"
typedef marian::bergamot::ResponseOptions TranslationRequest;
using namespace emscripten;
// Binding code
EMSCRIPTEN_BINDINGS(translation_request) { class_<TranslationRequest>("TranslationRequest").constructor<>(); }

View File

@ -1,22 +0,0 @@
/*
* Bindings for TranslationResult class
*
*/
#include <emscripten/bind.h>
#include <vector>
#include "response.h"
typedef marian::bergamot::Response TranslationResult;
using namespace emscripten;
// Binding code
EMSCRIPTEN_BINDINGS(translation_result) {
class_<TranslationResult>("TranslationResult")
.constructor<>()
.function("getOriginalText", &TranslationResult::getOriginalText)
.function("getTranslatedText", &TranslationResult::getTranslatedText);
}

View File

@ -0,0 +1,24 @@
/*
* Bindings for Response class
*
*/
#include <emscripten/bind.h>
#include <vector>
#include "response.h"
typedef marian::bergamot::Response Response;
using namespace emscripten;
// Binding code
EMSCRIPTEN_BINDINGS(response) {
class_<Response>("Response")
.constructor<>()
.function("getOriginalText", &Response::getOriginalText)
.function("getTranslatedText", &Response::getTranslatedText);
register_vector<Response>("VectorResponse");
}

View File

@ -0,0 +1,15 @@
/*
* Bindings for ResponseOptions class
*
*/
#include <emscripten/bind.h>
#include "response_options.h"
typedef marian::bergamot::ResponseOptions ResponseOptions;
using namespace emscripten;
// Binding code
EMSCRIPTEN_BINDINGS(response_options) { class_<ResponseOptions>("ResponseOptions").constructor<>(); }

View File

@ -1,18 +1,14 @@
/* /*
* TranslationModelBindings.cpp * Bindings for Service class
*
* Bindings for TranslationModel class
*/ */
#include <emscripten/bind.h> #include <emscripten/bind.h>
#include "response.h"
#include "service.h" #include "service.h"
using namespace emscripten; using namespace emscripten;
typedef marian::bergamot::Service TranslationModel; typedef marian::bergamot::Service Service;
typedef marian::bergamot::Response TranslationResult;
typedef marian::bergamot::AlignedMemory AlignedMemory; typedef marian::bergamot::AlignedMemory AlignedMemory;
val getByteArrayView(AlignedMemory& alignedMemory) { val getByteArrayView(AlignedMemory& alignedMemory) {
@ -29,7 +25,7 @@ EMSCRIPTEN_BINDINGS(aligned_memory) {
} }
// When source and target vocab files are same, only one memory object is passed from JS to // When source and target vocab files are same, only one memory object is passed from JS to
// avoid allocating memory twice for the same file. However, the constructor of the TranslationModel // avoid allocating memory twice for the same file. However, the constructor of the Service
// class still expects 2 entries in this case, where each entry has the shared ownership of the // class still expects 2 entries in this case, where each entry has the shared ownership of the
// same AlignedMemory object. This function prepares these smart pointer based AlignedMemory objects // same AlignedMemory object. This function prepares these smart pointer based AlignedMemory objects
// for unique AlignedMemory objects passed from JS. // for unique AlignedMemory objects passed from JS.
@ -56,21 +52,18 @@ marian::bergamot::MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, A
return memoryBundle; return memoryBundle;
} }
TranslationModel* TranslationModelFactory(const std::string& config, AlignedMemory* modelMemory, Service* ServiceFactory(const std::string& config, AlignedMemory* modelMemory, AlignedMemory* shortlistMemory,
AlignedMemory* shortlistMemory, std::vector<AlignedMemory*> uniqueVocabsMemories) {
std::vector<AlignedMemory*> uniqueVocabsMemories) { return new Service(config, std::move(prepareMemoryBundle(modelMemory, shortlistMemory, uniqueVocabsMemories)));
return new TranslationModel(config,
std::move(prepareMemoryBundle(modelMemory, shortlistMemory, uniqueVocabsMemories)));
} }
EMSCRIPTEN_BINDINGS(translation_model) { EMSCRIPTEN_BINDINGS(translation_service) {
class_<TranslationModel>("TranslationModel") class_<Service>("Service")
.constructor(&TranslationModelFactory, allow_raw_pointers()) .constructor(&ServiceFactory, allow_raw_pointers())
.function("translate", &TranslationModel::translateMultiple) .function("translate", &Service::translateMultiple)
.function("isAlignmentSupported", &TranslationModel::isAlignmentSupported); .function("isAlignmentSupported", &Service::isAlignmentSupported);
// ^ We redirect Service::translateMultiple to WASMBound::translate instead. Sane API is // ^ We redirect Service::translateMultiple to WASMBound::translate instead. Sane API is
// translate. If and when async comes, we can be done with this inconsistency. // translate. If and when async comes, we can be done with this inconsistency.
register_vector<std::string>("VectorString"); register_vector<std::string>("VectorString");
register_vector<TranslationResult>("VectorTranslationResult");
} }

View File

@ -80,8 +80,8 @@ En consecuencia, durante el año 2011 se introdujeron 180 proyectos de ley que r
return alignedMemory; return alignedMemory;
} }
var translationModel, request, input = undefined; var translationService, responseOptions, input = undefined;
const constructTranslationModel = async (from, to) => { const constructTranslationService = async (from, to) => {
const languagePair = `${from}${to}`; const languagePair = `${from}${to}`;
@ -162,10 +162,10 @@ gemm-precision: int8shift
var alignedVocabsMemoryList = new Module.AlignedMemoryList; var alignedVocabsMemoryList = new Module.AlignedMemoryList;
downloadedVocabBuffers.forEach(item => alignedVocabsMemoryList.push_back(constructAlignedMemoryFromBuffer(item, 64))); downloadedVocabBuffers.forEach(item => alignedVocabsMemoryList.push_back(constructAlignedMemoryFromBuffer(item, 64)));
// Instantiate the TranslationModel // Instantiate the Translation Service
if (translationModel) translationModel.delete(); if (translationService) translationService.delete();
console.debug("Creating TranslationModel with config:", modelConfig); console.debug("Creating Translation Service with config:", modelConfig);
translationModel = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList); translationService = new Module.Service(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList);
} catch (error) { } catch (error) {
log(error); log(error);
} }
@ -173,8 +173,8 @@ gemm-precision: int8shift
const translate = (paragraphs) => { const translate = (paragraphs) => {
// Instantiate the arguments of translate() API i.e. TranslationRequest and input (vector<string>) // Instantiate the arguments of translate() API i.e. ResponseOptions and input (vector<string>)
var request = new Module.TranslationRequest(); var responseOptions = new Module.ResponseOptions();
let input = new Module.VectorString; let input = new Module.VectorString;
// Initialize the input // Initialize the input
@ -188,14 +188,14 @@ gemm-precision: int8shift
// Access input (just for debugging) // Access input (just for debugging)
console.log('Input size=', input.size()); console.log('Input size=', input.size());
// Translate the input; the result is a vector<TranslationResult> // Translate the input; the result is a vector<Response>
let result = translationModel.translate(input, request); let result = translationService.translate(input, responseOptions);
const translatedParagraphs = []; const translatedParagraphs = [];
for (let i = 0; i < result.size(); i++) { for (let i = 0; i < result.size(); i++) {
translatedParagraphs.push(result.get(i).getTranslatedText()); translatedParagraphs.push(result.get(i).getTranslatedText());
} }
console.log({ translatedParagraphs }); console.log({ translatedParagraphs });
request.delete(); responseOptions.delete();
input.delete(); input.delete();
return translatedParagraphs; return translatedParagraphs;
} }
@ -206,10 +206,10 @@ gemm-precision: int8shift
const from = lang.substring(0, 2); const from = lang.substring(0, 2);
const to = lang.substring(2, 4); const to = lang.substring(2, 4);
let start = Date.now(); let start = Date.now();
await constructTranslationModel(from, to); await constructTranslationService(from, to);
log(`translation model ${from}${to} construction took ${(Date.now() - start) / 1000} secs`); log(`translation service ${from}${to} construction took ${(Date.now() - start) / 1000} secs`);
document.querySelector("#load").disabled = false; document.querySelector("#load").disabled = false;
//log('Model Alignment:', translationModel.isAlignmentSupported()); //log('Model Alignment:', translationService.isAlignmentSupported());
}); });
const translateCall = () => { const translateCall = () => {