diff --git a/.github/workflows/native-full_marian-mac.yml b/.github/workflows/native-full_marian-mac.yml index 44f1568..a2b62d5 100644 --- a/.github/workflows/native-full_marian-mac.yml +++ b/.github/workflows/native-full_marian-mac.yml @@ -45,10 +45,9 @@ jobs: working-directory: build run: make -j2 - # Removing unit-tests, taken care of in browsermt/marian-dev - # - name: Run unit tests - # - working-directory: build - # - run: make test + - name: Run unit tests + working-directory: build + run: make test - name: Print versions working-directory: build diff --git a/.github/workflows/native-full_marian-ubuntu.yml b/.github/workflows/native-full_marian-ubuntu.yml index 6ab8ea6..e67c366 100644 --- a/.github/workflows/native-full_marian-ubuntu.yml +++ b/.github/workflows/native-full_marian-ubuntu.yml @@ -103,13 +103,11 @@ jobs: working-directory: build run: make -j2 - # Removing unit-tests, taken care of in browsermt/marian-dev - # TODO: add a flag to CMake to compile unit tests only on CPU - # - name: Run unit tests - # working-directory: build - # run: make test - # # GitHub-hosted VMs do not have GPUs, so can not be run in CUDA builds - # if: matrix.gpu == false + - name: Run unit tests + working-directory: build + run: make test + # GitHub-hosted VMs do not have GPUs, so can not be run in CUDA builds + if: matrix.gpu == false - name: Print versions working-directory: build diff --git a/CMakeLists.txt b/CMakeLists.txt index 6417bd6..206a5b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,8 @@ include(CMakeDependentOption) # Project specific cmake options option(COMPILE_WASM "Compile for WASM" OFF) option(USE_WASM_COMPATIBLE_SOURCES "Use wasm compatible sources" ON) +option(COMPILE_TESTS "Compile bergamot-tests" OFF) + SET(PACKAGE_DIR "" CACHE STRING "Directory including all the files to be packaged (pre-loaded) in wasm builds") # Set 3rd party submodule specific cmake options for this project @@ -56,6 +58,11 @@ if(COMPILE_WASM) list(APPEND 
WASM_COMPILE_FLAGS -Wno-error=pthreads-mem-growth) endif(COMPILE_WASM) +# Needs to be enabled before including the folder containing tests (src/tests) +if(COMPILE_TESTS) + enable_testing() +endif(COMPILE_TESTS) + add_subdirectory(3rd_party) add_subdirectory(src) @@ -64,3 +71,4 @@ if(COMPILE_WASM) else() add_subdirectory(app) endif(COMPILE_WASM) + diff --git a/app/service-cli-bytearray.cpp b/app/service-cli-bytearray.cpp index d967810..cb3b17f 100644 --- a/app/service-cli-bytearray.cpp +++ b/app/service-cli-bytearray.cpp @@ -28,7 +28,7 @@ int main(int argc, char *argv[]) { std::future responseFuture = service.translate(std::move(input)); responseFuture.wait(); Response response = responseFuture.get(); - std::cout << response.translation() << std::endl; + std::cout << response.target.text << std::endl; // Clear the memory used for the byte array free(model_bytes); // Ideally, this should be done after the translation model has been gracefully shut down. diff --git a/app/service-cli.cpp b/app/service-cli.cpp index 2bb825c..6ed4d81 100644 --- a/app/service-cli.cpp +++ b/app/service-cli.cpp @@ -25,7 +25,56 @@ int main(int argc, char *argv[]) { std::future responseFuture = service.translate(std::move(input)); responseFuture.wait(); Response response = responseFuture.get(); - std::cout << response.translation() << std::endl; + + std::cout << "[original]: " << response.source.text << '\n'; + std::cout << "[translated]: " << response.target.text << '\n'; + for (int sentenceIdx = 0; sentenceIdx < response.size(); sentenceIdx++) { + std::cout << " [src Sentence]: " << response.source.sentence(sentenceIdx) + << '\n'; + std::cout << " [tgt Sentence]: " << response.target.sentence(sentenceIdx) + << '\n'; + std::cout << "Alignments" << '\n'; + typedef std::pair Point; + + // Initialize a point vector. 
+ std::vector> aggregate( + response.source.numWords(sentenceIdx)); + + // Handle alignments + auto &alignments = response.alignments[sentenceIdx]; + for (auto &p : alignments) { + aggregate[p.src].emplace_back(p.tgt, p.prob); + } + + for (size_t src = 0; src < aggregate.size(); src++) { + std::cout << response.source.word(sentenceIdx, src) << ": "; + for (auto &p : aggregate[src]) { + std::cout << response.target.word(sentenceIdx, p.first) << "(" + << p.second << ") "; + } + std::cout << '\n'; + } + + // Handle quality. + auto &quality = response.qualityScores[sentenceIdx]; + std::cout << "Quality: whole(" << quality.sequence + << "), tokens below:" << '\n'; + size_t wordIdx = 0; + bool first = true; + for (auto &p : quality.word) { + if (first) { + first = false; + } else { + std::cout << " "; + } + std::cout << response.target.word(sentenceIdx, wordIdx) << "(" << p + << ")"; + wordIdx++; + } + std::cout << '\n'; + } + std::cout << "--------------------------\n"; + std::cout << '\n'; return 0; } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 27fecc4..c2d62ef 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1 +1,7 @@ -add_subdirectory(translator) \ No newline at end of file +add_subdirectory(translator) + +if(COMPILE_TESTS) + # Catch currently comes from marian sources. 
+ add_subdirectory(tests) +endif(COMPILE_TESTS) + diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt new file mode 100644 index 0000000..5c1bc00 --- /dev/null +++ b/src/tests/CMakeLists.txt @@ -0,0 +1,22 @@ +# Unit tests +set(UNIT_TESTS + annotation_tests +) + +foreach(test ${UNIT_TESTS}) + add_executable("run_${test}" run_tests.cpp "${test}.cpp") + target_include_directories("run_${test}" PRIVATE ${CATCH_INCLUDE_DIR} "${CMAKE_SOURCE_DIR}/src") + + if(CUDA_FOUND) + target_link_libraries("run_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS} Catch bergamot-translator) + else(CUDA_FOUND) + target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch bergamot-translator) + endif(CUDA_FOUND) + + if(msvc) + # disable c4305: truncation from 'double' to '_ty' + target_compile_options("run_${test}" public /wd4305) + endif(msvc) + + add_test(NAME ${test} COMMAND "run_${test}") +endforeach(test) diff --git a/src/tests/annotation_tests.cpp b/src/tests/annotation_tests.cpp new file mode 100644 index 0000000..2284011 --- /dev/null +++ b/src/tests/annotation_tests.cpp @@ -0,0 +1,73 @@ +#include "catch.hpp" +#include "translator/sentence_ranges.h" +#include +#include + +using namespace marian::bergamot; + +TEST_CASE("Test Annotation API with random sentences") { + /// Objective here is to test insertion for sentences, and that whatever comes + /// out adheres to the way it was inserted. Towards this, we keep externally + /// which sentence went in where and try to use accessor methods on + /// AnnotatedText to check if what we have as ground-truth by construction is + /// consistent with what is returned. 
+ size_t sentences = 20; + size_t maxWords = 40; + + std::mt19937 randomIntGen_; + randomIntGen_.seed(42); + + AnnotatedText testAnnotation; + std::vector> sentenceWords; + std::vector Words; + + for (size_t idx = 0; idx < sentences; idx++) { + if (idx != 0) + testAnnotation.text += "\n"; + + Words.clear(); + size_t words = randomIntGen_() % maxWords + 1; + Words.reserve(words); + for (size_t idw = 0; idw < words; idw++) { + size_t before = testAnnotation.text.size(); + std::string word = std::to_string(idx) + "-" + std::to_string(idw); + testAnnotation.text += word; + if (idw != 0) + testAnnotation.text += " "; + Words.push_back((ByteRange){before, before + word.size() - 1}); + } + // std::cout << std::endl; + + sentenceWords.push_back(Words); + } + + // std::cout << "Inserting words:" << std::endl; + std::vector> byteRanges; + for (auto &sentence : sentenceWords) { + std::vector wordByteRanges; + for (auto &word : sentence) { + marian::string_view wordView(&testAnnotation.text[word.begin], + word.end - word.begin); + wordByteRanges.push_back(wordView); + // std::cout << std::string(wordView) << " "; + } + testAnnotation.addSentence(wordByteRanges); + byteRanges.push_back(wordByteRanges); + // std::cout << std::endl; + } + + // std::cout << "From container: " << std::endl; + for (int idx = 0; idx < sentenceWords.size(); idx++) { + for (int idw = 0; idw < sentenceWords[idx].size(); idw++) { + ByteRange expected = sentenceWords[idx][idw]; + ByteRange obtained = testAnnotation.wordAsByteRange(idx, idw); + // std::cout << std::string(testAnnotation.word(idx, idw)) << " "; + CHECK(expected.begin == obtained.begin); + CHECK(expected.end == obtained.end); + + std::string expected_string = std::string(byteRanges[idx][idw]); + CHECK(expected_string == std::string(testAnnotation.word(idx, idw))); + } + // std::cout << std::endl; + } +} diff --git a/src/tests/run_tests.cpp b/src/tests/run_tests.cpp new file mode 100644 index 0000000..0c7c351 --- /dev/null +++ 
b/src/tests/run_tests.cpp @@ -0,0 +1,2 @@ +#define CATCH_CONFIG_MAIN +#include "catch.hpp" diff --git a/src/translator/TranslationModel.cpp b/src/translator/TranslationModel.cpp index 72de208..13f9495 100644 --- a/src/translator/TranslationModel.cpp +++ b/src/translator/TranslationModel.cpp @@ -32,24 +32,19 @@ TranslationModel::translate(std::vector &&texts, intermediate.wait(); auto marianResponse(std::move(intermediate.get())); - // This mess because marian::string_view != std::string_view - std::string source, translation; - marian::bergamot::Response::SentenceMappings mSentenceMappings; - marianResponse.move(source, translation, mSentenceMappings); - - // Convert to UnifiedAPI::TranslationResult TranslationResult::SentenceMappings sentenceMappings; - for (auto &p : mSentenceMappings) { - std::string_view src(p.first.data(), p.first.size()), - tgt(p.second.data(), p.second.size()); - sentenceMappings.emplace_back(src, tgt); + for (size_t idx = 0; idx < marianResponse.size(); idx++) { + marian::string_view src = marianResponse.source.sentence(idx); + marian::string_view tgt = marianResponse.target.sentence(idx); + sentenceMappings.emplace_back(std::string_view(src.data(), src.size()), + std::string_view(tgt.data(), tgt.size())); } // In place construction. 
translationResults.emplace_back( - std::move(source), // &&marianResponse.source_ - std::move(translation), // &&marianResponse.translation_ - std::move(sentenceMappings) // &&sentenceMappings + std::move(marianResponse.source.text), // &&marianResponse.source_ + std::move(marianResponse.target.text), // &&marianResponse.translation_ + std::move(sentenceMappings) // &&sentenceMappings ); } diff --git a/src/translator/request.cpp b/src/translator/request.cpp index 42dfb35..b6d2438 100644 --- a/src/translator/request.cpp +++ b/src/translator/request.cpp @@ -12,12 +12,10 @@ namespace bergamot { // ----------------------------------------------------------------- Request::Request(size_t Id, size_t lineNumberBegin, - std::vector> &vocabs, std::string &&source, - Segments &&segments, SentenceRanges &&sourceRanges, - std::promise responsePromise) + std::vector> &vocabs, AnnotatedText &&source, + Segments &&segments, std::promise responsePromise) : Id_(Id), lineNumberBegin_(lineNumberBegin), vocabs_(&vocabs), source_(std::move(source)), segments_(std::move(segments)), - sourceRanges_(std::move(sourceRanges)), response_(std::move(responsePromise)) { counter_ = segments_.size(); @@ -48,8 +46,7 @@ void Request::processHistory(size_t index, Ptr history) { void Request::completeRequest() { // Request no longer needs to hold the content, can transfer it to // Response. 
- Response response(std::move(source_), std::move(sourceRanges_), - std::move(histories_), *vocabs_); + Response response(std::move(source_), std::move(histories_), *vocabs_); response_.set_value(std::move(response)); } diff --git a/src/translator/request.h b/src/translator/request.h index 3909019..605dea7 100644 --- a/src/translator/request.h +++ b/src/translator/request.h @@ -1,9 +1,9 @@ // // Defines: // -// Request: holds the input blob of a text, Segments (vector) which are +// Request: holds the input text of a text, Segments (vector) which are // to go to the batching mechanism and alignments between the processed -// segments and the input blob (sourceTokenRanges). In addition, Request takes +// segments and the input text (sourceTokenRanges). In addition, Request takes // care of the barrier which fires when all the Segments in a request are done // translating by the workers (BatchTranslator). // TODO(jerinphilip): Extend Request with notions of Priority (sequence, @@ -36,9 +36,8 @@ namespace bergamot { class Request { public: Request(size_t Id, size_t lineNumberBegin, - std::vector> &vocabs_, std::string &&source, - Segments &&segments, SentenceRanges &&sourceTokenRanges, - std::promise responsePromise); + std::vector> &vocabs_, AnnotatedText &&source, + Segments &&segments, std::promise responsePromise); // Obtain the count of tokens in the segment correponding to index. Used to // insert sentence from multiple requests into the corresponding size bucket. @@ -77,9 +76,8 @@ private: // string_views of the text corresponding to these words, pointing to // sequences in source_. histories_ is a buffer which eventually stores the // translations of each segment in the corresponding index. 
- std::string source_; + AnnotatedText source_; Segments segments_; - SentenceRanges sourceRanges_; std::vector> histories_; // Members above are moved into newly constructed Response on completion diff --git a/src/translator/response.cpp b/src/translator/response.cpp index b731755..faa42da 100644 --- a/src/translator/response.cpp +++ b/src/translator/response.cpp @@ -1,49 +1,25 @@ #include "response.h" -#include "sentence_ranges.h" #include "common/logging.h" #include "data/alignment.h" +#include "sentence_ranges.h" #include namespace marian { namespace bergamot { -Response::Response(std::string &&source, SentenceRanges &&sourceRanges, - Histories &&histories, std::vector> &vocabs) - : source_(std::move(source)), sourceRanges_(std::move(sourceRanges)), - histories_(std::move(histories)), vocabs_(&vocabs) {} - -void Response::move(std::string &source, std::string &translation, - SentenceMappings &sentenceMappings) { - - // Construct required stuff first. - constructTranslation(); - constructSentenceMappings(sentenceMappings); - - // Move content out. - source = std::move(source_); - translation = std::move(translation_); - - // The above assignment expects source, target be moved. - // which makes the following invalid, hence required to be cleared. - sourceRanges_.clear(); - targetRanges_.clear(); - histories_.clear(); -} - -void Response::constructTranslation() { - if (translationConstructed_) { - return; - } - +Response::Response(AnnotatedText &&source, Histories &&histories, + std::vector> &vocabs) + : source(std::move(source)), histories_(std::move(histories)) { // Reserving length at least as much as source_ seems like a reasonable thing // to do to avoid reallocations. - translation_.reserve(source_.size()); + target.text.reserve(source.text.size()); // In a first step, the decoded units (individual senteneces) are compiled // into a huge string. This is done by computing indices first and appending // to the string as each sentences are decoded. 
std::vector> translationRanges; + std::vector sentenceBegins; size_t offset{0}; bool first{true}; @@ -54,44 +30,76 @@ void Response::constructTranslation() { Result result = onebest[0]; // Expecting only one result; Words words = std::get<0>(result); - auto targetVocab = vocabs_->back(); - std::string decoded = targetVocab->decode(words); + auto targetVocab = vocabs.back(); + + std::string decoded; + std::vector targetMappings; + targetVocab->decodeWithByteRanges(words, decoded, targetMappings); + if (first) { first = false; } else { - translation_ += " "; + target.text += " "; ++offset; } - translation_ += decoded; - translationRanges.emplace_back(offset, decoded.size()); + sentenceBegins.push_back(translationRanges.size()); + target.text += decoded; + auto decodedStringBeginMarker = targetMappings.front().begin(); + for (auto &sview : targetMappings) { + size_t startIdx = offset + sview.begin() - decodedStringBeginMarker; + translationRanges.emplace_back(startIdx, startIdx + sview.size()); + } + offset += decoded.size(); + + // Alignments + // TODO(jerinphilip): The following double conversion might not be + // necessary. Hard alignment can directly be exported, but this would mean + // WASM bindings for a structure deep within marian source. + auto hyp = std::get<1>(result); + auto softAlignment = hyp->tracebackAlignment(); + auto hardAlignment = data::ConvertSoftAlignToHardAlign( + softAlignment, /*threshold=*/0.2f); // TODO(jerinphilip): Make this a + // configurable parameter. + + Alignment unified_alignment; + for (auto &p : hardAlignment) { + unified_alignment.emplace_back((Point){p.srcPos, p.tgtPos, p.prob}); + } + + alignments.push_back(std::move(unified_alignment)); + + // Quality scores: Sequence level is obtained as normalized path scores. + // Word level using hypothesis traceback. These are most-likely logprobs. 
+ auto normalizedPathScore = std::get<2>(result); + auto wordQualities = hyp->tracebackWordScores(); + wordQualities.pop_back(); + qualityScores.push_back((Quality){normalizedPathScore, wordQualities}); } - // Once the entire string is constructed, there are no further possibility of - // reallocation in the string's storage, the indices are converted into - // string_views. + // Once we have the indices in translation (which might be resized a few + // times) ready, we can prepare and store the string_view as annotations + // instead. This is accomplished by iterating over available sentences using + // sentenceBegin and using addSentence(...) API from Annotation. - for (auto &range : translationRanges) { - // TODO(@jerinphilip): Currently considers target tokens as whole text. - // Needs to be further enhanced in marian-dev to extract alignments. + for (size_t i = 1; i <= sentenceBegins.size(); i++) { std::vector targetMappings; + size_t begin = sentenceBegins[i - 1]; + size_t safe_end = (i == sentenceBegins.size()) ? 
translationRanges.size() + : sentenceBegins[i]; - const char *begin = &translation_[range.first]; - targetMappings.emplace_back(begin, range.second); - targetRanges_.addSentence(targetMappings); - } + for (size_t idx = begin; idx < safe_end; idx++) { + auto &p = translationRanges[idx]; + size_t begin_idx = p.first; + size_t end_idx = p.second; - translationConstructed_ = true; -} + const char *data = &target.text[begin_idx]; + size_t size = end_idx - begin_idx; + targetMappings.emplace_back(data, size); + } -void Response::constructSentenceMappings( - Response::SentenceMappings &sentenceMappings) { - - for (size_t i = 0; i < sourceRanges_.numSentences(); i++) { - string_view src = sourceRanges_.sentence(i); - string_view tgt = targetRanges_.sentence(i); - sentenceMappings.emplace_back(src, tgt); + target.addSentence(targetMappings); } } } // namespace bergamot diff --git a/src/translator/response.h b/src/translator/response.h index 17fee05..385735c 100644 --- a/src/translator/response.h +++ b/src/translator/response.h @@ -1,9 +1,10 @@ #ifndef SRC_BERGAMOT_RESPONSE_H_ #define SRC_BERGAMOT_RESPONSE_H_ -#include "sentence_ranges.h" +#include "data/alignment.h" #include "data/types.h" #include "definitions.h" +#include "sentence_ranges.h" #include "translator/beam_search.h" #include @@ -12,86 +13,87 @@ namespace marian { namespace bergamot { + +/// Alignment is stored as a sparse matrix, this pretty much aligns with marian +/// internals but is brought here to maintain translator +/// agnosticism/independence. +struct Point { + size_t src; ///< Index pointing to source ByteRange + size_t tgt; ///< Index pointing to target ByteRange + float prob; ///< Score between [0, 1] indicating degree of alignment. +}; + +/// Alignment is a sparse matrix, where Points represent entries with values. +typedef std::vector Alignment; + +/// -loglikelihoods of the sequence components as proxy to quality. +struct Quality { + /// Certainty/uncertainty score for sequence. 
+ float sequence; + /// Certainty/uncertainty for each word in the sequence. + std::vector word; +}; + +/// Response holds AnnotatedText(s) of source-text and translated text, +/// alignment information between source and target sub-words and sentences. +/// +/// AnnotatedText provides an API to access markings of (sub)-word and +/// sentences boundaries, which are required to interpret Quality and +/// Alignment (s) at the moment. class Response { - // Response is a marian internal class (not a bergamot-translator class) - // holding source blob of text, vector of TokenRanges corresponding to each - // sentence in the source text blob and histories obtained from translating - // these sentences. - // - // This class provides an API at a higher level in comparison to History to - // access translations and additionally use string_view manipulations to - // recover structure in translation from source-text's structure known through - // reference string and string_view. As many of these computations are not - // required until invoked, they are computed as required and stored in data - // members where it makes sense to do so (translation,translationTokenRanges). - // - // Examples of such use-cases are: - // translation() - // translationInSourceStructure() TODO(@jerinphilip) - // alignment(idx) TODO(@jerinphilip) - // sentenceMappings (for bergamot-translator) public: - Response(std::string &&source, SentenceRanges &&sourceRanges, - Histories &&histories, - // Required for constructing translation and TokenRanges within - // translation lazily. + /// + Response(AnnotatedText &&source, Histories &&histories, std::vector> &vocabs); + /// \cond HIDDEN_PUBLIC // Move constructor. 
Response(Response &&other) - : source_(std::move(other.source_)), - translation_(std::move(other.translation_)), - sourceRanges_(std::move(other.sourceRanges_)), - targetRanges_(std::move(other.targetRanges_)), - histories_(std::move(other.histories_)), - vocabs_(std::move(other.vocabs_)){}; + : source(std::move(other.source)), target(std::move(other.target)), + alignments(std::move(other.alignments)), + qualityScores(std::move(other.qualityScores)), + histories_(std::move(other.histories_)){}; + + // The following copy bans are not strictly required anymore since Annotation + // is composed of the ByteRange primitive (which was previously string_view + // and required to be bound to string), but make moves efficient by + // banning these, letting the compiler complain about copies. - // Prevents CopyConstruction and CopyAssignment. sourceRanges_ is constituted - // by string_view and copying invalidates the data member. Response(const Response &) = delete; Response &operator=(const Response &) = delete; - typedef std::vector> - SentenceMappings; + /// \endcond - // Moves source sentence into source, translated text into translation. - // Pairs of string_views to corresponding sentences in - // source and translation are loaded into sentenceMappings. These string_views - // reference the new source and translation. - // - // Calling move() invalidates the Response object as ownership is transferred. - // Exists for moving strc - void move(std::string &source, std::string &translation, - SentenceMappings &sentenceMappings); + /// Number of sentences translated. The processing of a text into sentences + /// is handled internally, and this information can be used to iterate + /// through meaningful units of translation for which alignment and quality + /// information are available. + const size_t size() const { return source.numSentences(); } + /// source text and annotations of (sub-)words and sentences. 
+ AnnotatedText source; + + /// translated text and annotations of (sub-)words and sentences. + AnnotatedText target; + + /// -logprob of each word and negative log likelihood of sequence (sentence) + /// normalized by length, for each sentence processed by the translator. + /// Indices correspond to ranges accessible through respective Annotation on + /// source or target. + std::vector qualityScores; + + /// Alignments between source and target. Each Alignment is a + /// sparse matrix representation with indices corresponding + /// to (sub-)words accessible through Annotation. + std::vector alignments; + + /// Access to histories, which holds rich information on translated text. + /// Not recommended to use, will be removed in future. const Histories &histories() const { return histories_; } - const std::string &source() const { return source_; } - const std::string &translation() { - constructTranslation(); - return translation_; - } - - // A convenience function provided to return translated text placed within - // source's structure. This is useful when the source text is a multi-line - // paragraph or string_views extracted from structured text like HTML and it's - // desirable to place the individual sentences in the locations of the source - // sentences. 
- // const std::string translationInSourceStructure(); - // const PendingAlignmentType alignment(size_t idx); private: - void constructTranslation(); - void constructSentenceMappings(SentenceMappings &); - - std::string source_; - SentenceRanges sourceRanges_; Histories histories_; - - std::vector> *vocabs_; - bool translationConstructed_{false}; - std::string translation_; - SentenceRanges targetRanges_; }; } // namespace bergamot } // namespace marian diff --git a/src/translator/sentence_ranges.cpp b/src/translator/sentence_ranges.cpp index a9ee8c5..053eeaa 100644 --- a/src/translator/sentence_ranges.cpp +++ b/src/translator/sentence_ranges.cpp @@ -5,40 +5,87 @@ namespace marian { namespace bergamot { -void SentenceRanges::addSentence(std::vector &wordRanges) { - addSentence(std::begin(wordRanges), std::end(wordRanges)); -} - -void SentenceRanges::addSentence(WordIterator begin, WordIterator end) { +void Annotation::addSentence(std::vector &sentence) { size_t size = flatByteRanges_.size(); - flatByteRanges_.insert(std::end(flatByteRanges_), begin, end); + flatByteRanges_.insert(std::end(flatByteRanges_), std::begin(sentence), + std::end(sentence)); sentenceBeginIds_.push_back(size); } -string_view SentenceRanges::sentence(size_t index) const { - size_t bos_id; - string_view eos, bos; - - bos_id = sentenceBeginIds_[index]; - bos = flatByteRanges_[bos_id]; - - if (index + 1 == numSentences()) { - eos = flatByteRanges_.back(); - } else { - assert(index < numSentences()); - size_t eos_id = sentenceBeginIds_[index + 1]; - --eos_id; - eos = flatByteRanges_[eos_id]; - } - - return sentenceBetween(bos, eos); +size_t Annotation::numWords(size_t sentenceIdx) const { + auto terminals = sentenceTerminalIds(sentenceIdx); + return terminals.second - terminals.first + 1; } -string_view SentenceRanges::sentenceBetween(string_view firstWord, - string_view lastWord) const { +std::pair +Annotation::sentenceTerminalIds(size_t sentenceIdx) const { + size_t bosId, eosId; + bosId = 
sentenceBeginIds_[sentenceIdx]; + eosId = sentenceIdx + 1 < numSentences() + ? sentenceBeginIds_[sentenceIdx + 1] - 1 + : flatByteRanges_.size() - 1; - const char *data = firstWord.data(); - size_t size = lastWord.data() + lastWord.size() - firstWord.data(); + // Out of bound checks. + assert(bosId < flatByteRanges_.size()); + assert(eosId < flatByteRanges_.size()); + return std::make_pair(bosId, eosId); +} + +std::pair +Annotation::sentenceTerminals(size_t sentenceIdx) const { + auto terminals = sentenceTerminalIds(sentenceIdx); + return std::make_pair(flatByteRanges_[terminals.first], + flatByteRanges_[terminals.second]); +} + +ByteRange Annotation::sentence(size_t sentenceIdx) const { + auto terminals = sentenceTerminals(sentenceIdx); + return (ByteRange){terminals.first.begin, terminals.second.end}; +} + +ByteRange Annotation::word(size_t sentenceIdx, size_t wordIdx) const { + size_t offset = sentenceBeginIds_[sentenceIdx]; + // auto terminals = sentenceTerminals(sentenceIdx); + // assert(offset + wordIdx <= terminals.second); + return flatByteRanges_[offset + wordIdx]; +} + +string_view AnnotatedText::word(size_t sentenceIdx, size_t wordIdx) const { + auto terminals = annotation.word(sentenceIdx, wordIdx); + return string_view(&text[terminals.begin], terminals.size()); +} + +string_view AnnotatedText::sentence(size_t sentenceIdx) const { + auto sentenceAsByteRange = annotation.sentence(sentenceIdx); + return asStringView(sentenceAsByteRange); +} + +void AnnotatedText::addSentence(std::vector &wordRanges) { + addSentence(std::begin(wordRanges), std::end(wordRanges)); +}; + +void AnnotatedText::addSentence(std::vector::iterator begin, + std::vector::iterator end) { + std::vector sentence; + for (auto p = begin; p != end; p++) { + size_t begin_offset = p->data() - &text[0]; + sentence.push_back((ByteRange){begin_offset, begin_offset + p->size()}); + } + annotation.addSentence(sentence); +}; + +ByteRange AnnotatedText::wordAsByteRange(size_t sentenceIdx, + size_t 
wordIdx) const { + return annotation.word(sentenceIdx, wordIdx); +} + +ByteRange AnnotatedText::sentenceAsByteRange(size_t sentenceIdx) const { + return annotation.sentence(sentenceIdx); +} + +string_view AnnotatedText::asStringView(const ByteRange &byteRange) const { + const char *data = &text[byteRange.begin]; + size_t size = byteRange.size(); return string_view(data, size); } diff --git a/src/translator/sentence_ranges.h b/src/translator/sentence_ranges.h index c6a0770..a0dc8c9 100644 --- a/src/translator/sentence_ranges.h +++ b/src/translator/sentence_ranges.h @@ -3,50 +3,134 @@ #include "data/types.h" #include +#include #include namespace marian { namespace bergamot { -class SentenceRanges { - // SentenceRanges stores string_views into a source text, with additional - // annotations to mark sentence boundaries. - // - // Given the availability annotations, this container provides capabilty to - // add sentences, and access individual sentences. +/// ByteRange stores indices for half-interval [begin, end) in a string. Can be +/// used to represent a sentence, word. +struct ByteRange { + size_t begin; + size_t end; + const size_t size() const { return end - begin; } +}; + +/// An Annotation is a collection of ByteRanges used to denote ancillary +/// information of sentences and words on a text of string. Annotation is meant +/// for consumption on platforms where string_view creates problems (eg: exports +/// through WASM). See AnnotatedText for cases where this is a non-issue. +class Annotation { public: - typedef std::vector::iterator WordIterator; - - void addSentence(std::vector &wordRanges); - void addSentence(WordIterator begin, WordIterator end); - - void clear() { - flatByteRanges_.clear(); - sentenceBeginIds_.clear(); - } + /// Annotation is constructed empty. See addSentence to populate it with + /// annotations. + Annotation() {} + /// Returns the number of sentences annotated in a text. 
size_t numSentences() const { return sentenceBeginIds_.size(); } - // Returns a string_view into the ith sentence. - string_view sentence(size_t index) const; + /// Returns number of words in the sentence identified by sentenceIdx. + size_t numWords(size_t sentenceIdx) const; + + /// Adds a sentence from its vector representation, internally doing + /// extra book-keeping for the sentence terminal markings. Sentences are + /// expected to be added in order as they occur in text. + void addSentence(std::vector &sentence); + + /// Returns a ByteRange representing wordIdx in sentenceIdx + ByteRange word(size_t sentenceIdx, size_t wordIdx) const; + + /// Returns a ByteRange representing sentence corresponding to sentenceIdx. + ByteRange sentence(size_t sentenceIdx) const; private: - // A flat storage for string_views. Can be words or sentences. - std::vector flatByteRanges_; + /// A flat storage for ByteRanges. Composed of word ByteRanges, extra + /// information in sentenceBeginIds_ to denote sentence boundary markers as + /// indices. + std::vector flatByteRanges_; - // The container grows dynamically with addSentence. size_t marking index is - // used to ensure the sentence boundaries stay same while underlying storage - // might be changed during reallocation. + /// Stores indices where sentences begin std::vector sentenceBeginIds_; - // Utility function to extract the string starting at firstWord and ending at - // lastWord as a single string-view. - string_view sentenceBetween(string_view firstWord, - string_view lastWord) const; + /// Returns ByteRanges corresponding to beginning and end words of sentence + /// corresponding to sentenceIdx. This is useful in using the information to + /// construct a ByteRange of a sentence taking the begin from the first and + /// end from the second. + std::pair sentenceTerminals(size_t sentenceIdx) const; + + /// Returns indices of terminal (word) ByteRanges in flatByteRanges_ of a + /// sentence corresponding to sentenceIdx. 
The distance can be used to compute + /// number of words in a sentence (numWords) and also to construct the + /// terminal ByteRanges (sentenceTerminals). + std::pair sentenceTerminalIds(size_t sentenceIdx) const; +}; + +/// AnnotatedText is effectively std::string text + Annotation, providing the +/// following additional desiderata. +/// +/// 1. Access to processed string_views for convenience rather than ByteRanges +/// (which only provide index information). +/// +/// 2. Transparently convert string_views into ByteRanges for the Annotation +/// referring to the text bound by this structure. +/// +/// 3. Bind the text and annotations together, to move around as a meaningful +/// unit. + +struct AnnotatedText { +public: + std::string text; ///< Blob of string that elements in annotation refer to. + Annotation annotation; ///< sentence and (sub-) word annotations. + + /// Construct an empty AnnotatedText. This is useful when the target string or + /// ByteRanges are not known yet, but the public members can be used to + /// populate it. One use-case, when translated-text is created decoding from + /// histories and the ByteRanges are only known after the string has been + /// constructed. + AnnotatedText() {} + + /// Construct moving in a string (for efficiency purposes, copying string + /// constructor is disallowed). + AnnotatedText(std::string &&text) : text(std::move(text)){}; + + AnnotatedText(AnnotatedText &&annotatedBlob) + : text(std::move(annotatedBlob.text)), + annotation(std::move(annotatedBlob.annotation)) {} + + /// Returns the number of sentences in the annotation structure. + const size_t numSentences() const { return annotation.numSentences(); } + + /// Returns number of words in the sentence identified by sentenceIdx. + const size_t numWords(size_t sentenceIdx) const { + return annotation.numWords(sentenceIdx); + } + + /// Adds a sentence, used to load from SentencePiece annotations conveniently.
+ void addSentence(std::vector &wordRanges); + + /// Adds a sentence between two iterators, often useful while constructing + /// from parts of a container. + void addSentence(std::vector::iterator begin, + std::vector::iterator end); + + /// Returns a string_view representing wordIdx in sentenceIdx + string_view word(size_t sentenceIdx, size_t wordIdx) const; + + /// Returns a string_view representing sentence corresponding to sentenceIdx. + string_view sentence(size_t sentenceIdx) const; + + /// Returns a ByteRange representing wordIdx in sentenceIdx + ByteRange wordAsByteRange(size_t sentenceIdx, size_t wordIdx) const; + + /// Returns a ByteRange representing sentence corresponding to sentenceIdx. + ByteRange sentenceAsByteRange(size_t sentenceIdx) const; + +private: + string_view asStringView(const ByteRange &byteRange) const; }; } // namespace bergamot - } // namespace marian #endif // BERGAMOT_SENTENCE_RANGES_H_ diff --git a/src/translator/service.cpp b/src/translator/service.cpp index 8c8c453..80b2a10 100644 --- a/src/translator/service.cpp +++ b/src/translator/service.cpp @@ -112,15 +112,15 @@ void Service::async_translate() { std::future Service::translate(std::string &&input) { Segments segments; - SentenceRanges sourceRanges; - text_processor_.process(input, segments, sourceRanges); + AnnotatedText source(std::move(input)); + text_processor_.process(source, segments); std::promise responsePromise; auto future = responsePromise.get_future(); Ptr request = New( - requestId_++, /* lineNumberBegin = */ 0, vocabs_, std::move(input), - std::move(segments), std::move(sourceRanges), std::move(responsePromise)); + requestId_++, /* lineNumberBegin = */ 0, vocabs_, std::move(source), + std::move(segments), std::move(responsePromise)); batcher_.addWholeRequest(request); if (numWorkers_ == 0) { diff --git a/src/translator/service.h b/src/translator/service.h index bb8dbe9..04ec2b8 100644 --- a/src/translator/service.h +++ b/src/translator/service.h @@ -25,9 +25,9 @@ 
namespace bergamot { /// /// options = ...; /// service = Service(options); -/// std::string input_blob = "Hello World"; +/// std::string input_text = "Hello World"; /// std::future -/// response = service.translate(std::move(input_blob)); +/// response = service.translate(std::move(input_text)); /// response.wait(); /// Response result = response.get(); /// diff --git a/src/translator/text_processor.cpp b/src/translator/text_processor.cpp index 9d6733e..8d7f25c 100644 --- a/src/translator/text_processor.cpp +++ b/src/translator/text_processor.cpp @@ -25,9 +25,9 @@ TextProcessor::TextProcessor(std::vector> &vocabs, ABORT_IF(max_length_break_ < 0, "max-length-break cannot be < 0"); } -void TextProcessor::process(const string_view &query, Segments &segments, - SentenceRanges &sourceRanges) { +void TextProcessor::process(AnnotatedText &source, Segments &segments) { + string_view query = string_view(source.text); auto sentenceStream = sentence_splitter_.createSentenceStream(query); std::string_view sentenceStringPiece; @@ -42,14 +42,14 @@ void TextProcessor::process(const string_view &query, Segments &segments, // after normalization. 0 prevents any empty entries from being added. if (segment.size() > 0) { // Truncate segment into max_input_size segments. 
- truncate(segment, wordRanges, segments, sourceRanges); + truncate(segment, wordRanges, segments, source); } } } void TextProcessor::truncate(Segment &segment, std::vector &wordRanges, - Segments &segments, SentenceRanges &sourceRanges) { + Segments &segments, AnnotatedText &source) { for (size_t offset = 0; offset < segment.size(); offset += max_length_break_) { auto start = segment.begin() + offset; @@ -61,7 +61,7 @@ void TextProcessor::truncate(Segment &segment, segments.back().push_back(sourceEosId()); auto astart = wordRanges.begin() + offset; - sourceRanges.addSentence(astart, astart + diff); + source.addSentence(astart, astart + diff); } } diff --git a/src/translator/text_processor.h b/src/translator/text_processor.h index 4cd1761..ed3c773 100644 --- a/src/translator/text_processor.h +++ b/src/translator/text_processor.h @@ -23,8 +23,7 @@ class TextProcessor { public: explicit TextProcessor(std::vector> &vocabs, Ptr); - void process(const string_view &query, Segments &segments, - SentenceRanges &sourceRanges); + void process(AnnotatedText &source, Segments &segments); private: // Tokenizes an input string, returns Words corresponding. Loads the @@ -34,7 +33,7 @@ private: // Truncate sentence into max_input_size segments. void truncate(Segment &sentence, std::vector &tokenRanges, - Segments &segments, SentenceRanges &sourceRanges); + Segments &segments, AnnotatedText &source); // shorthand, used only in truncate() const Word sourceEosId() const { return vocabs_->front()->getEosId(); }