Refactor wasm bindings to use consistent interface names as in native (#195)

* Refactored wasm bindings code
 - Replaced TranslationModel, TranslationRequest and TranslationResult
    with Service, ResponseOptions and Response
 - Corresponding documentation changes
 - Names of the bindings files changed
 - Moved Vector<Response> definition in Response specific bindings
   file
This commit is contained in:
Abhishek Aggarwal 2021-06-15 16:02:14 +02:00 committed by GitHub
parent 4b014665ba
commit b00116cb94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 77 additions and 82 deletions

View File

@ -105,7 +105,7 @@ class Service {
/// recommended to work with futures and translate() API.
///
/// @param [in] source: rvalue reference of the string to be translated
/// @param [in] translationRequest: ResponseOptions indicating whether or not
/// @param [in] responseOptions: ResponseOptions indicating whether or not
/// to include some member in the Response, also specify any additional
/// configurable parameters.
std::vector<Response> translateMultiple(std::vector<std::string> &&source, ResponseOptions responseOptions);

View File

@ -1,7 +1,7 @@
add_executable(bergamot-translator-worker
bindings/TranslationModelBindings.cpp
bindings/TranslationRequestBindings.cpp
bindings/TranslationResultBindings.cpp
bindings/service_bindings.cpp
bindings/response_options_bindings.cpp
bindings/response_bindings.cpp
)
# Generate version file that can be included in the wasm artifacts

View File

@ -63,27 +63,27 @@ var alignedShortlistMemory = constructAlignedMemoryFromBuffer(shortListBuffer, 6
var alignedVocabsMemoryList = new Module.AlignedMemoryList;
downloadedVocabBuffers.forEach(item => alignedVocabsMemoryList.push_back(constructAlignedMemoryFromBuffer(item, 64)));
// Instantiate the TranslationModel
const model = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList);
// Instantiate the Translation Service
const translationService = new Module.Service(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList);
// Instantiate the arguments of translate() API i.e. TranslationRequest and input (vector<string>)
const request = new Module.TranslationRequest();
// Instantiate the arguments of translate() API i.e. ResponseOptions and input (vector<string>)
const responseOptions = new Module.ResponseOptions();
const input = new Module.VectorString;
// Initialize the input
input.push_back("Hola"); input.push_back("Mundo");
// translate the input; the result is a vector<TranslationResult>
const result = model.translate(input, request);
// translate the input; the result is a vector<Response>
const result = translationService.translate(input, responseOptions);
// Print original and translated text from each entry of vector<TranslationResult>
// Print original and translated text from each entry of vector<Response>
for (let i = 0; i < result.size(); i++) {
console.log(' original=' + result.get(i).getOriginalText() + ', translation=' + result.get(i).getTranslatedText());
}
// Don't forget to clean up the instances
model.delete();
request.delete();
translationService.delete();
responseOptions.delete();
input.delete();
```

View File

@ -1,15 +0,0 @@
/*
* Bindings for TranslationRequest class
*
*/
#include <emscripten/bind.h>
#include "response_options.h"
typedef marian::bergamot::ResponseOptions TranslationRequest;
using namespace emscripten;
// Binding code
EMSCRIPTEN_BINDINGS(translation_request) { class_<TranslationRequest>("TranslationRequest").constructor<>(); }

View File

@ -1,22 +0,0 @@
/*
* Bindings for TranslationResult class
*
*/
#include <emscripten/bind.h>
#include <vector>
#include "response.h"
typedef marian::bergamot::Response TranslationResult;
using namespace emscripten;
// Binding code
EMSCRIPTEN_BINDINGS(translation_result) {
class_<TranslationResult>("TranslationResult")
.constructor<>()
.function("getOriginalText", &TranslationResult::getOriginalText)
.function("getTranslatedText", &TranslationResult::getTranslatedText);
}

View File

@ -0,0 +1,24 @@
/*
* Bindings for Response class
*
*/
#include <emscripten/bind.h>
#include <vector>
#include "response.h"
typedef marian::bergamot::Response Response;
using namespace emscripten;
// Binding code
EMSCRIPTEN_BINDINGS(response) {
class_<Response>("Response")
.constructor<>()
.function("getOriginalText", &Response::getOriginalText)
.function("getTranslatedText", &Response::getTranslatedText);
register_vector<Response>("VectorResponse");
}

View File

@ -0,0 +1,15 @@
/*
* Bindings for ResponseOptions class
*
*/
#include <emscripten/bind.h>
#include "response_options.h"
typedef marian::bergamot::ResponseOptions ResponseOptions;
using namespace emscripten;
// Binding code
EMSCRIPTEN_BINDINGS(response_options) { class_<ResponseOptions>("ResponseOptions").constructor<>(); }

View File

@ -1,18 +1,14 @@
/*
* TranslationModelBindings.cpp
*
* Bindings for TranslationModel class
* Bindings for Service class
*/
#include <emscripten/bind.h>
#include "response.h"
#include "service.h"
using namespace emscripten;
typedef marian::bergamot::Service TranslationModel;
typedef marian::bergamot::Response TranslationResult;
typedef marian::bergamot::Service Service;
typedef marian::bergamot::AlignedMemory AlignedMemory;
val getByteArrayView(AlignedMemory& alignedMemory) {
@ -29,7 +25,7 @@ EMSCRIPTEN_BINDINGS(aligned_memory) {
}
// When source and target vocab files are same, only one memory object is passed from JS to
// avoid allocating memory twice for the same file. However, the constructor of the TranslationModel
// avoid allocating memory twice for the same file. However, the constructor of the Service
// class still expects 2 entries in this case, where each entry has the shared ownership of the
// same AlignedMemory object. This function prepares these smart pointer based AlignedMemory objects
// for unique AlignedMemory objects passed from JS.
@ -56,21 +52,18 @@ marian::bergamot::MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, A
return memoryBundle;
}
TranslationModel* TranslationModelFactory(const std::string& config, AlignedMemory* modelMemory,
AlignedMemory* shortlistMemory,
std::vector<AlignedMemory*> uniqueVocabsMemories) {
return new TranslationModel(config,
std::move(prepareMemoryBundle(modelMemory, shortlistMemory, uniqueVocabsMemories)));
Service* ServiceFactory(const std::string& config, AlignedMemory* modelMemory, AlignedMemory* shortlistMemory,
std::vector<AlignedMemory*> uniqueVocabsMemories) {
return new Service(config, std::move(prepareMemoryBundle(modelMemory, shortlistMemory, uniqueVocabsMemories)));
}
EMSCRIPTEN_BINDINGS(translation_model) {
class_<TranslationModel>("TranslationModel")
.constructor(&TranslationModelFactory, allow_raw_pointers())
.function("translate", &TranslationModel::translateMultiple)
.function("isAlignmentSupported", &TranslationModel::isAlignmentSupported);
EMSCRIPTEN_BINDINGS(translation_service) {
class_<Service>("Service")
.constructor(&ServiceFactory, allow_raw_pointers())
.function("translate", &Service::translateMultiple)
.function("isAlignmentSupported", &Service::isAlignmentSupported);
// ^ We redirect Service::translateMultiple to WASMBound::translate instead. Sane API is
// translate. If and when async comes, we can be done with this inconsistency.
register_vector<std::string>("VectorString");
register_vector<TranslationResult>("VectorTranslationResult");
}

View File

@ -80,8 +80,8 @@ En consecuencia, durante el año 2011 se introdujeron 180 proyectos de ley que r
return alignedMemory;
}
var translationModel, request, input = undefined;
const constructTranslationModel = async (from, to) => {
var translationService, responseOptions, input = undefined;
const constructTranslationService = async (from, to) => {
const languagePair = `${from}${to}`;
@ -162,10 +162,10 @@ gemm-precision: int8shift
var alignedVocabsMemoryList = new Module.AlignedMemoryList;
downloadedVocabBuffers.forEach(item => alignedVocabsMemoryList.push_back(constructAlignedMemoryFromBuffer(item, 64)));
// Instantiate the TranslationModel
if (translationModel) translationModel.delete();
console.debug("Creating TranslationModel with config:", modelConfig);
translationModel = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList);
// Instantiate the Translation Service
if (translationService) translationService.delete();
console.debug("Creating Translation Service with config:", modelConfig);
translationService = new Module.Service(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList);
} catch (error) {
log(error);
}
@ -173,8 +173,8 @@ gemm-precision: int8shift
const translate = (paragraphs) => {
// Instantiate the arguments of translate() API i.e. TranslationRequest and input (vector<string>)
var request = new Module.TranslationRequest();
// Instantiate the arguments of translate() API i.e. ResponseOptions and input (vector<string>)
var responseOptions = new Module.ResponseOptions();
let input = new Module.VectorString;
// Initialize the input
@ -188,14 +188,14 @@ gemm-precision: int8shift
// Access input (just for debugging)
console.log('Input size=', input.size());
// Translate the input; the result is a vector<TranslationResult>
let result = translationModel.translate(input, request);
// Translate the input; the result is a vector<Response>
let result = translationService.translate(input, responseOptions);
const translatedParagraphs = [];
for (let i = 0; i < result.size(); i++) {
translatedParagraphs.push(result.get(i).getTranslatedText());
}
console.log({ translatedParagraphs });
request.delete();
responseOptions.delete();
input.delete();
return translatedParagraphs;
}
@ -206,10 +206,10 @@ gemm-precision: int8shift
const from = lang.substring(0, 2);
const to = lang.substring(2, 4);
let start = Date.now();
await constructTranslationModel(from, to);
log(`translation model ${from}${to} construction took ${(Date.now() - start) / 1000} secs`);
await constructTranslationService(from, to);
log(`translation service ${from}${to} construction took ${(Date.now() - start) / 1000} secs`);
document.querySelector("#load").disabled = false;
//log('Model Alignment:', translationModel.isAlignmentSupported());
//log('Model Alignment:', translationService.isAlignmentSupported());
});
const translateCall = () => {