diff --git a/src/tests/common-impl.cpp b/src/tests/common-impl.cpp index 43b1c07..431ddaa 100644 --- a/src/tests/common-impl.cpp +++ b/src/tests/common-impl.cpp @@ -155,10 +155,11 @@ void TestSuite::qualityEstimatorWords(Ptr model) { std::string source = readFromStdin(); const Response response = bridge_.translate(service_, model, std::move(source), responseOptions); - for (const auto &sentenceQualityEstimate : response.qualityScores) { + for (size_t sentenceIdx = 0; sentenceIdx < response.qualityScores.size(); ++sentenceIdx) { + const auto &sentenceQualityEstimate = response.qualityScores[sentenceIdx]; std::cout << "[SentenceBegin]\n"; - for (const auto &wordByteRange : sentenceQualityEstimate.wordByteRanges) { + for (const auto &wordByteRange : getWordByteRanges(response, sentenceIdx)) { const string_view word(response.target.text.data() + wordByteRange.begin, wordByteRange.size()); std::cout << word << "\n"; } diff --git a/src/translator/definitions.h b/src/translator/definitions.h index eb1e672..b3bc101 100644 --- a/src/translator/definitions.h +++ b/src/translator/definitions.h @@ -42,6 +42,16 @@ struct ByteRange { bool operator==(ByteRange other) const { return begin == other.begin && end == other.end; } }; +/// A Subword range is mechanically the same as a `ByteRange`, but instead of +/// describing a span of bytes, it describes a span of Subword tokens. Using +/// `Annotation.word()` you can switch between the two. 
+struct SubwordRange { + size_t begin; ///< Index of the first subword token in the span. + size_t end; ///< Index one past the last subword token, so the span is [begin, end). + size_t size() const { return end - begin; } ///< Number of subword tokens in the span. + bool operator==(SubwordRange other) const { return begin == other.begin && end == other.end; } +}; + class Response; using CallbackType = std::function; diff --git a/src/translator/html.cpp b/src/translator/html.cpp index ed42b91..5180a5a 100644 --- a/src/translator/html.cpp +++ b/src/translator/html.cpp @@ -3,6 +3,7 @@ #include #include "response.h" +#include "translator/definitions.h" #include "xh_scanner.h" namespace { @@ -544,7 +545,12 @@ void HTML::restore(Response &response) { copyTagStack(response, alignments, sourceTokenSpans, targetTokenSpans); assert(targetTokenSpans.size() == debugCountTokens(response.target)); - AnnotatedText target = restoreTarget(response.target, targetTokenSpans); + // Take the spans, and use them to make a taint for every word in the + // translation. Optionally add extra tags, like quality score metadata. + std::vector targetTokenTags; + annotateTagStack(response, targetTokenSpans, targetTokenTags); + + AnnotatedText target = restoreTarget(response.target, targetTokenSpans, targetTokenTags); response.source = source; response.target = target; @@ -592,38 +598,37 @@ AnnotatedText HTML::restoreSource(AnnotatedText const &in, std::vector const &targetTokenSpans) { - auto prevSpan = spans_.cbegin(); +AnnotatedText HTML::restoreTarget(AnnotatedText const &in, std::vector const &targetTokenSpans, + std::vector const &targetTokenTags) { + auto prevTags = spans_.cbegin()->tags; + auto stragglerSpanIt = spans_.cbegin(); auto targetSpanIt = targetTokenSpans.begin(); - auto straggerSpanIt = spans_.cbegin(); + auto targetTagIt = targetTokenTags.begin(); AnnotatedText out = in.apply([&]([[maybe_unused]] ByteRange range, string_view token, bool last) { TokenFormatter formatter(token); // First we scan through spans_ to catch up to the span assigned to this // token. 
We're only interested in empty spans (empty and void elements) - for (; straggerSpanIt < *targetSpanIt; ++straggerSpanIt) { + for (; stragglerSpanIt < *targetSpanIt; stragglerSpanIt++) { // We're only interested in empty spans or spans that would otherwise get // lost because they didn't align with anything between the spans in // targetSpanIt // TODO That std::find makes this O(N*N) NOT GOOD NOT GOOD - if (straggerSpanIt->size() != 0 && - std::find(targetTokenSpans.begin(), targetTokenSpans.end(), straggerSpanIt) != targetTokenSpans.end()) + if (stragglerSpanIt->size() != 0 && + std::find(targetTokenSpans.begin(), targetTokenSpans.end(), stragglerSpanIt) != targetTokenSpans.end()) continue; - formatter.append(prevSpan->tags, straggerSpanIt->tags); - - // Note: here, not in 3rd part of for-statement because we don't want to - // set prevSpan if the continue clause at the beginning of this for-loop - // was hit. - prevSpan = straggerSpanIt; + formatter.append(prevTags, stragglerSpanIt->tags); + prevTags = stragglerSpanIt->tags; } // Now do the same thing but for our target set of tags. Note that we cannot // combine this in the for-loop above (i.e. `span_it <= *targetSpanIt`) // because there is no guarantee that the order in `targetTokenSpans` is // the same as that of `spans`. - formatter.append(prevSpan->tags, (*targetSpanIt)->tags); + + formatter.append(prevTags, *targetTagIt); // If this is the last token of the response, close all open tags. 
if (last) { @@ -632,11 +637,12 @@ AnnotatedText HTML::restoreTarget(AnnotatedText const &in, std::vectortags.empty()); - formatter.append((*targetSpanIt)->tags, HTML::TagStack()); + formatter.append(*targetTagIt, HTML::TagStack()); } - prevSpan = *targetSpanIt; + prevTags = *targetTagIt; ++targetSpanIt; + ++targetTagIt; return std::move(formatter.html()); }); @@ -674,6 +680,56 @@ void HTML::copyTagStack(Response const &response, std::vector const &targetTokenSpans, + std::vector &targetTokenTags) { + auto spanIt = targetTokenSpans.begin(); + for (size_t sentenceIdx = 0; sentenceIdx < response.target.numSentences(); ++sentenceIdx) { + // Sentence prefix + targetTokenTags.push_back((*spanIt)->tags); + spanIt++; + + // Offset in targetTokenTags at which this sentence's tags start. + size_t tagOffset = targetTokenTags.size(); + + // Initially, just copy the span's tags to this token + for (size_t t = 0; t < response.target.numWords(sentenceIdx); ++t) { + targetTokenTags.emplace_back((*spanIt)->tags); + spanIt++; + } + + // If we have quality score information, add that as metadata as well. + if (!response.qualityScores.empty()) { + auto const &sentenceQuality = response.qualityScores[sentenceIdx]; + // Create a single tag for this sentence with sentence level info + Tag *sentenceTag = makeTag({Tag::ELEMENT, "font"}); + sentenceTag->attributes += format(" x-bergamot-sentence-index=\"{}\" x-bergamot-sentence-score=\"{}\"", + sentenceIdx, sentenceQuality.sentenceScore); + + // Add that tag to all tokens in this sentence. + for (size_t tokenIdx = 0; tokenIdx < response.target.numWords(sentenceIdx); ++tokenIdx) { + targetTokenTags[tagOffset + tokenIdx].push_back(sentenceTag); + } + + // Add word level tags as well to all tokens that make up a word. 
+ for (size_t wordIdx = 0; wordIdx < sentenceQuality.wordRanges.size(); ++wordIdx) { + Tag *wordTag = makeTag({Tag::ELEMENT, "font"}); + wordTag->attributes += format(" x-bergamot-word-index=\"{}\" x-bergamot-word-score=\"{}\"", wordIdx, + sentenceQuality.wordScores[wordIdx]); + auto const &range = sentenceQuality.wordRanges[wordIdx]; + for (size_t tokenIdx = range.begin; tokenIdx < range.end; ++tokenIdx) { + targetTokenTags[tagOffset + tokenIdx].push_back(wordTag); + } + } + } + } + + // Suffix + targetTokenTags.push_back((*spanIt)->tags); + spanIt++; + + assert(spanIt == targetTokenSpans.end()); +} + // Reports if token `str` is likely to be a continuation of a word. This is used // to determine whether we should share the markup, or whether we should see // this token as a fresh start. This implementation will treat "hello[world]" diff --git a/src/translator/html.h b/src/translator/html.h index c704c59..f3c6dad 100644 --- a/src/translator/html.h +++ b/src/translator/html.h @@ -162,7 +162,7 @@ class HTML { void restore(Response &response); private: - using SpanIterator = std::vector::const_iterator; + using SpanIterator = std::vector::iterator; using AnnotatedText = marian::bergamot::AnnotatedText; /// Reconstructs HTML in `response.source` (passed as `in`) and makes a list @@ -175,7 +175,8 @@ class HTML { /// Inserts the HTML into `response.target` (passed as `in`) based on /// `targetTokenSpans`, which points to a `Span` for each token (subword) in /// `response.target`. - AnnotatedText restoreTarget(AnnotatedText const &in, std::vector const &targetTokenSpans); + AnnotatedText restoreTarget(AnnotatedText const &in, std::vector const &targetTokenSpans, + std::vector const &targetTokenTags); /// Utilities to test whether subword `str` is part of a word together with /// the subword `prev`, or a separate word. 
Basically *does `str` start with @@ -190,6 +191,9 @@ class HTML { std::vector const &sourceTokenSpans, std::vector &targetTokenSpans); + void annotateTagStack(Response const &response, std::vector const &targetTokenSpans, + std::vector &targetTokenTags); + /// Turns the alignment scores in `response.alignments` into one source token /// per target token. Has some heuristics to keep all target tokens of a /// single word pointing to the same span, and prefers spans with more markup diff --git a/src/translator/quality_estimator.cpp b/src/translator/quality_estimator.cpp index 936d293..24ca2c2 100644 --- a/src/translator/quality_estimator.cpp +++ b/src/translator/quality_estimator.cpp @@ -27,7 +27,7 @@ Response::SentenceQualityScore UnsupervisedQualityEstimator::computeSentenceScor const float sentenceScore = std::accumulate(std::begin(wordScores), std::end(wordScores), float(0.0)) / wordScores.size(); - return {wordScores, subwordToWords(wordIndices, target, sentenceIdx), sentenceScore}; + return {wordScores, wordIndices, sentenceScore}; } LogisticRegressorQualityEstimator::Matrix::Matrix(const size_t rowsParam, const size_t colsParam) @@ -160,7 +160,7 @@ Response::SentenceQualityScore LogisticRegressorQualityEstimator::computeSentenc const float sentenceScore = std::accumulate(std::begin(wordScores), std::end(wordScores), float(0.0)) / wordScores.size(); - return {wordScores, subwordToWords(wordIndices, target, sentenceIdx), sentenceScore}; + return {wordScores, wordIndices, sentenceScore}; } std::vector LogisticRegressorQualityEstimator::predict(const Matrix& features) const { @@ -267,22 +267,4 @@ std::vector mapWords(const std::vector& logProbs, const Ann return wordIndices; } -std::vector subwordToWords(const std::vector& wordIndices, const AnnotatedText& target, - const size_t sentenceIdx) { - std::vector words; - - for (const SubwordRange& wordIndice : wordIndices) { - size_t wordBegin = target.wordAsByteRange(sentenceIdx, wordIndice.begin).begin; - size_t 
wordEnd = target.wordAsByteRange(sentenceIdx, wordIndice.end).begin; - - if (isspace(target.text.at(wordBegin))) { - ++wordBegin; - } - - words.emplace_back(ByteRange{wordBegin, wordEnd}); - } - - return words; -} - } // namespace marian::bergamot diff --git a/src/translator/quality_estimator.h b/src/translator/quality_estimator.h index 3d2fd68..b8d1596 100644 --- a/src/translator/quality_estimator.h +++ b/src/translator/quality_estimator.h @@ -21,8 +21,6 @@ class QualityEstimator { virtual void computeQualityScores(const Histories &histories, Response &response) const = 0; }; -using SubwordRange = ByteRange; - /// Unsupervised Quality Estimator model. It uses the translator model's log probabilities (log probs) as a proxy for /// quality scores. Then, for a given word, its quality score is computed by taking the mean of the log probs of the /// tokens that make it up. The sentence score is the mean of all word's log probs. @@ -209,14 +207,4 @@ inline std::shared_ptr createQualityEstimator(const AlignedMem std::vector mapWords(const std::vector &logProbs, const AnnotatedText &target, const size_t sentenceIdx); -/// Given a vector of subwordRanges, it maps the elements to be real words rather than sublevel tokens. The words are -/// represented through ByteRanges. 
- -/// @param [in] wordIndices: A vector where each element correspond to the index of a real word and its values are -/// represented by the SubwordRanges (which are aliases of ByteRanges) which represents sublevel token positions -/// @param [in] target: AnnotatedText target value -/// @param [in] sentenceIdx: the id of a candidate sentence -std::vector subwordToWords(const std::vector &wordIndices, const AnnotatedText &target, - const size_t sentenceIdx); - } // namespace marian::bergamot diff --git a/src/translator/response.cpp b/src/translator/response.cpp index 8e623a7..135ec47 100644 --- a/src/translator/response.cpp +++ b/src/translator/response.cpp @@ -142,4 +142,27 @@ std::vector remapAlignments(const Response &first, const Response &se return alignments; } +/// Returns, for sentence `sentenceIdx`, one `ByteRange` into `response.target.text` per entry of +/// `response.qualityScores[sentenceIdx].wordRanges`. Each range starts at the first token of the word (skipping one +/// leading whitespace byte, if present) and ends where the token at index `word.end` begins. +std::vector<ByteRange> getWordByteRanges(const Response &response, size_t sentenceIdx) { + std::vector<ByteRange> wordByteRanges; + wordByteRanges.reserve(response.qualityScores[sentenceIdx].wordRanges.size()); + + for (auto &&word : response.qualityScores[sentenceIdx].wordRanges) { + size_t wordBegin = response.target.wordAsByteRange(sentenceIdx, word.begin).begin; + size_t wordEnd = response.target.wordAsByteRange(sentenceIdx, word.end).begin; + + // std::isspace has undefined behaviour for negative arguments, which a plain `char` holds for any byte + // >= 0x80 when `char` is signed, so convert to unsigned char before the call. + if (std::isspace(static_cast<unsigned char>(response.target.text.at(wordBegin)))) { + ++wordBegin; + } + + wordByteRanges.emplace_back(ByteRange{wordBegin, wordEnd}); + } + + return wordByteRanges; +} + } // namespace marian::bergamot diff --git a/src/translator/response.h b/src/translator/response.h index 74463ed..af05e10 100644 --- a/src/translator/response.h +++ b/src/translator/response.h @@ -30,8 +30,8 @@ struct Response { struct SentenceQualityScore { /// Quality score of each translated word std::vector wordScores; - /// Each word position in the translated text - std::vector wordByteRanges; + /// Position of start and end token of each word in the translated text + std::vector wordRanges; /// Whole sentence quality score (it is composed by the mean of its words) float sentenceScore = 0.0; }; @@ 
-77,6 +77,8 @@ struct Response { std::vector remapAlignments(const Response &first, const Response &second); +std::vector getWordByteRanges(Response const &response, size_t sentenceIdx); + } // namespace bergamot } // namespace marian diff --git a/wasm/bindings/response_bindings.cpp b/wasm/bindings/response_bindings.cpp index 11bc4ca..51a46ab 100644 --- a/wasm/bindings/response_bindings.cpp +++ b/wasm/bindings/response_bindings.cpp @@ -10,7 +10,6 @@ #include "response.h" using Response = marian::bergamot::Response; -using SentenceQualityScore = marian::bergamot::Response::SentenceQualityScore; using ByteRange = marian::bergamot::ByteRange; using namespace emscripten; @@ -20,25 +19,14 @@ EMSCRIPTEN_BINDINGS(byte_range) { value_object("ByteRange").field("begin", &ByteRange::begin).field("end", &ByteRange::end); } -std::vector getQualityScores(const Response& response) { return response.qualityScores; } - EMSCRIPTEN_BINDINGS(response) { class_("Response") .constructor<>() .function("size", &Response::size) - .function("getQualityScores", &getQualityScores) .function("getOriginalText", &Response::getOriginalText) .function("getTranslatedText", &Response::getTranslatedText) .function("getSourceSentence", &Response::getSourceSentenceAsByteRange) .function("getTranslatedSentence", &Response::getTargetSentenceAsByteRange); - value_object("SentenceQualityScore") - .field("wordScores", &SentenceQualityScore::wordScores) - .field("wordByteRanges", &SentenceQualityScore::wordByteRanges) - .field("sentenceScore", &SentenceQualityScore::sentenceScore); - register_vector("VectorResponse"); - register_vector("VectorSentenceQualityScore"); - register_vector("VectorFloat"); - register_vector("VectorByteRange"); } diff --git a/wasm/test_page/css/index.css b/wasm/test_page/css/index.css index bbc5bf1..6ed6422 100644 --- a/wasm/test_page/css/index.css +++ b/wasm/test_page/css/index.css @@ -73,7 +73,7 @@ label { align-self: center; } -textarea { +textarea, .output-area { padding: 1rem; 
font-family: sans-serif; font-size: 1rem; @@ -97,3 +97,22 @@ button:hover { #output { background-color: #f4f4f4; } + +.output-area [x-bergamot-word-score].bad { + background-image: + linear-gradient(45deg, transparent 65%, red 80%, transparent 90%), + linear-gradient(135deg, transparent 5%, red 15%, transparent 25%), + linear-gradient(135deg, transparent 45%, red 55%, transparent 65%), + linear-gradient(45deg, transparent 25%, red 35%, transparent 50%); + background-repeat:repeat-x; + background-size: 8px 2px; + background-position:0 95%; +} + +.output-area [x-bergamot-sentence-score].bad { + background: rgba(255, 128, 128, 0.8); +} + +.output-area [x-bergamot-sentence-index].highlight-sentence { + background: rgba(255, 255, 128, 0.8); +} \ No newline at end of file diff --git a/wasm/test_page/index.html b/wasm/test_page/index.html index 86eae46..3f48117 100644 --- a/wasm/test_page/index.html +++ b/wasm/test_page/index.html @@ -24,7 +24,7 @@ To - +
diff --git a/wasm/test_page/js/index.js b/wasm/test_page/js/index.js index 7cb3657..166f0f2 100644 --- a/wasm/test_page/js/index.js +++ b/wasm/test_page/js/index.js @@ -38,22 +38,60 @@ const _prepareTranslateOptions = (paragraphs) => { return translateOptions; }; +const textToHTML = (text) => { + const div = document.createElement('div'); + div.appendChild(document.createTextNode(text)); + return div.innerHTML; +}; + const translateCall = () => { - const text = document.querySelector("#input").value + " "; + const text = document.querySelector("#input").value; if (!text.trim().length) return; - const paragraphs = text.split("\n"); + const paragraphs = text.split(/\n+/).map(textToHTML); // escape HTML const translateOptions = _prepareTranslateOptions(paragraphs); - $("#output").setAttribute("disabled", true); const lngFrom = langFrom.value; const lngTo = langTo.value; worker.postMessage(["translate", lngFrom, lngTo, paragraphs, translateOptions]); }; +// Toggles the `bad` class on every element under `root` that carries a sentence or word score attribute, +// and sets a tooltip with both scores on each word element that is a direct child of a sentence element. +const addQualityClasses = (root) => { + // You can do this with CSS variables, calc() and min/max, but JS is just easier + + root.querySelectorAll('[x-bergamot-sentence-score]').forEach(el => { + // Note: these thresholds are just examples, they are not good thresholds! Assumes scores are log probabilities (higher is better), so values below the threshold are flagged. + el.classList.toggle('bad', parseFloat(el.getAttribute('x-bergamot-sentence-score')) < -0.1); + }); + + root.querySelectorAll('[x-bergamot-word-score]').forEach(el => { + // Note: these thresholds are just examples, they are not good thresholds! Assumes scores are log probabilities (higher is better), so values below the threshold are flagged. + el.classList.toggle('bad', parseFloat(el.getAttribute('x-bergamot-word-score')) < -0.1); + }); + + // Add tooltips to each (sub)word with sentence and word score. 
+ root.querySelectorAll('[x-bergamot-sentence-score] > [x-bergamot-word-score]').forEach(el => { + const sentenceScore = parseFloat(el.parentNode.getAttribute('x-bergamot-sentence-score')); + const wordScore = parseFloat(el.getAttribute('x-bergamot-word-score')); + el.title = `Sentence: ${sentenceScore} Word: ${wordScore}`; + }); +} + worker.onmessage = function (e) { if (e.data[0] === "translate_reply" && e.data[1]) { - document.querySelector("#output").value = e.data[1].join("\n\n"); - $("#output").removeAttribute("disabled"); + // Clear output of previous translation + document.querySelector("#output").innerHTML = ''; + + // Add each translation in its own div to have a known root in which the + // sentence ids are unique. Used for highlighting sentences. + e.data[1].forEach(translatedHTML => { + const translation = document.createElement('div'); + translation.classList.add('translation'); + translation.innerHTML = translatedHTML; + addQualityClasses(translation); + document.querySelector("#output").appendChild(translation); + }); } else if (e.data[0] === "load_model_reply" && e.data[1]) { status(e.data[1]); translateCall(); @@ -76,8 +112,8 @@ const loadModel = () => { console.log(`Loading model '${lngFrom}${lngTo}'`); worker.postMessage(["load_model", lngFrom, lngTo]); } else { - const input = document.querySelector("#input").value; - document.querySelector("#output").value = input; + const input = textToHTML(document.querySelector("#input").value); + document.querySelector("#output").innerHTML = input; } }; @@ -95,6 +131,14 @@ $(".swap").addEventListener("click", e => { loadModel(); }); +$('#output').addEventListener('mouseover', e => { + const root = e.target.closest('.translation'); + const sentence = e.target.parentNode.hasAttribute('x-bergamot-sentence-index') ? 
e.target.parentNode.getAttribute('x-bergamot-sentence-index') : null; + document.querySelectorAll('#output font[x-bergamot-sentence-index]').forEach(el => { + el.classList.toggle('highlight-sentence', el.getAttribute('x-bergamot-sentence-index') === sentence && el.closest('.translation') === root); + }) +}) + function init() { // try to guess input language from user agent let myLang = navigator.language; diff --git a/wasm/test_page/js/worker.js b/wasm/test_page/js/worker.js index 292e2d6..aa4d404 100644 --- a/wasm/test_page/js/worker.js +++ b/wasm/test_page/js/worker.js @@ -137,13 +137,11 @@ const translate = (from, to, input, translateOptions) => { const listSourceText = _parseSourceText(vectorResponse); const listTranslatedTextSentences = _parseTranslatedTextSentences(vectorResponse); const listSourceTextSentences = _parseSourceTextSentences(vectorResponse); - const listTranslatedTextSentenceQualityScores = _parseTranslatedTextSentenceQualityScores(vectorResponse); log(`Source text: ${listSourceText}`); log(`Translated text: ${listTranslatedText}`); log(`Translated sentences: ${JSON.stringify(listTranslatedTextSentences)}`); log(`Source sentences: ${JSON.stringify(listSourceTextSentences)}`); - log(`Translated sentence quality scores: ${JSON.stringify(listTranslatedTextSentenceQualityScores)}`); return listTranslatedText; } finally { @@ -292,44 +290,6 @@ const _parseSourceTextSentences = (vectorResponse) => { return result; } -const _parseTranslatedTextSentenceQualityScores = (vectorResponse) => { - const result = []; - for (let i = 0; i < vectorResponse.size(); i++) { - const response = vectorResponse.get(i); - const translatedText = response.getTranslatedText(); - const vectorSentenceQualityScore = response.getQualityScores(); - log(`No. 
of sentences: "${vectorSentenceQualityScore.size()}"`); - const sentenceQualityScores = []; - for (let sentenceIndex=0; sentenceIndex < vectorSentenceQualityScore.size(); sentenceIndex++) { - const sentenceQualityScoreObject = vectorSentenceQualityScore.get(sentenceIndex); - const wordByteRangeList = []; - const wordList = []; - const wordScoreList = []; - const vectorWordScore = sentenceQualityScoreObject.wordScores; - const vectorWordByteRange = sentenceQualityScoreObject.wordByteRanges; - - for (let wordIndex = 0; wordIndex < vectorWordScore.size(); wordIndex++) { - const wordScore = vectorWordScore.get(wordIndex); - const wordByteRange = vectorWordByteRange.get(wordIndex); - wordScoreList.push(wordScore); - wordByteRangeList.push(wordByteRange); - const word = _getSubString(translatedText, wordByteRange); - wordList.push(word); - } - - const sentenceQualityScore = { - wordByteRanges: wordByteRangeList, - words: wordList, - wordScores: wordScoreList, - sentenceScore: sentenceQualityScoreObject.sentenceScore - }; - sentenceQualityScores.push(sentenceQualityScore); - } - result.push(sentenceQualityScores); - } - return result; -} - const _prepareResponseOptions = (translateOptions) => { let vectorResponseOptions = new Module.VectorResponseOptions; translateOptions.forEach(translateOption => {