Embed quality-scores as HTML tag attributes (#358)

Quality scores for HTML translation exposed as <font
x-bergamot-sentence-score=""> and <font x-bergamot-word-score=""> tags
in the HTML output. While this increases the size of the HTML returned,
the resulting rendered HTML can easily be styled to show the scores.
With Javascript or CSS, developers can easily have some interface based
on these extra attributes.

Also includes updates to the test page to show a proof-of-concept 
demonstration.

Fixes: #355
This commit is contained in:
Jelmer 2022-02-25 23:01:32 +01:00 committed by GitHub
parent 96b0f82343
commit fe3f3982de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 187 additions and 115 deletions

View File

@ -155,10 +155,11 @@ void TestSuite<Service>::qualityEstimatorWords(Ptr<TranslationModel> model) {
std::string source = readFromStdin();
const Response response = bridge_.translate(service_, model, std::move(source), responseOptions);
for (const auto &sentenceQualityEstimate : response.qualityScores) {
for (size_t sentenceIdx = 0; sentenceIdx < response.qualityScores.size(); ++sentenceIdx) {
const auto &sentenceQualityEstimate = response.qualityScores[sentenceIdx];
std::cout << "[SentenceBegin]\n";
for (const auto &wordByteRange : sentenceQualityEstimate.wordByteRanges) {
for (const auto &wordByteRange : getWordByteRanges(response, sentenceIdx)) {
const string_view word(response.target.text.data() + wordByteRange.begin, wordByteRange.size());
std::cout << word << "\n";
}

View File

@ -42,6 +42,16 @@ struct ByteRange {
bool operator==(ByteRange other) const { return begin == other.begin && end == other.end; }
};
/// A Subword range is mechanically the same as a `ByteRange`, but instead of
/// describing a span of bytes, it describes a span of Subword tokens. Using
/// `Annotation.word()` you can switch between the two.
struct SubwordRange {
size_t begin;
size_t end;
const size_t size() const { return end - begin; }
bool operator==(SubwordRange other) const { return begin == other.begin && end == other.end; }
};
class Response;
using CallbackType = std::function<void(Response &&)>;

View File

@ -3,6 +3,7 @@
#include <algorithm>
#include "response.h"
#include "translator/definitions.h"
#include "xh_scanner.h"
namespace {
@ -544,7 +545,12 @@ void HTML::restore(Response &response) {
copyTagStack(response, alignments, sourceTokenSpans, targetTokenSpans);
assert(targetTokenSpans.size() == debugCountTokens(response.target));
AnnotatedText target = restoreTarget(response.target, targetTokenSpans);
// Take the spans, and use them to make a taint for every word in the
// translation. Optionally add extra tags, like quality score metadata.
std::vector<HTML::TagStack> targetTokenTags;
annotateTagStack(response, targetTokenSpans, targetTokenTags);
AnnotatedText target = restoreTarget(response.target, targetTokenSpans, targetTokenTags);
response.source = source;
response.target = target;
@ -592,38 +598,37 @@ AnnotatedText HTML::restoreSource(AnnotatedText const &in, std::vector<SpanItera
});
}
AnnotatedText HTML::restoreTarget(AnnotatedText const &in, std::vector<SpanIterator> const &targetTokenSpans) {
auto prevSpan = spans_.cbegin();
AnnotatedText HTML::restoreTarget(AnnotatedText const &in, std::vector<SpanIterator> const &targetTokenSpans,
std::vector<TagStack> const &targetTokenTags) {
auto prevTags = spans_.cbegin()->tags;
auto stragglerSpanIt = spans_.cbegin();
auto targetSpanIt = targetTokenSpans.begin();
auto straggerSpanIt = spans_.cbegin();
auto targetTagIt = targetTokenTags.begin();
AnnotatedText out = in.apply([&]([[maybe_unused]] ByteRange range, string_view token, bool last) {
TokenFormatter formatter(token);
// First we scan through spans_ to catch up to the span assigned to this
// token. We're only interested in empty spans (empty and void elements)
for (; straggerSpanIt < *targetSpanIt; ++straggerSpanIt) {
for (; stragglerSpanIt < *targetSpanIt; stragglerSpanIt++) {
// We're only interested in empty spans or spans that would otherwise get
// lost because they didn't align with anything between the spans in
// targetSpanIt
// TODO That std::find makes this O(N*N) NOT GOOD NOT GOOD
if (straggerSpanIt->size() != 0 &&
std::find(targetTokenSpans.begin(), targetTokenSpans.end(), straggerSpanIt) != targetTokenSpans.end())
if (stragglerSpanIt->size() != 0 &&
std::find(targetTokenSpans.begin(), targetTokenSpans.end(), stragglerSpanIt) != targetTokenSpans.end())
continue;
formatter.append(prevSpan->tags, straggerSpanIt->tags);
// Note: here, not in 3rd part of for-statement because we don't want to
// set prevSpan if the continue clause at the beginning of this for-loop
// was hit.
prevSpan = straggerSpanIt;
formatter.append(prevTags, stragglerSpanIt->tags);
prevTags = stragglerSpanIt->tags;
}
// Now do the same thing but for our target set of tags. Note that we cannot
// combine this in the for-loop above (i.e. `span_it <= *targetSpanIt`)
// because there is no guarantee that the order in `targetTokenSpans` is
// the same as that of `spans`.
formatter.append(prevSpan->tags, (*targetSpanIt)->tags);
formatter.append(prevTags, *targetTagIt);
// If this is the last token of the response, close all open tags.
if (last) {
@ -632,11 +637,12 @@ AnnotatedText HTML::restoreTarget(AnnotatedText const &in, std::vector<SpanItera
// the last token of the output. But lets assume someone someday changes
// HardAlignments(), and then this for-loop will be necessary.
// assert((*targetSpanIt)->tags.empty());
formatter.append((*targetSpanIt)->tags, HTML::TagStack());
formatter.append(*targetTagIt, HTML::TagStack());
}
prevSpan = *targetSpanIt;
prevTags = *targetTagIt;
++targetSpanIt;
++targetTagIt;
return std::move(formatter.html());
});
@ -674,6 +680,56 @@ void HTML::copyTagStack(Response const &response, std::vector<std::vector<size_t
targetTokenSpans.push_back(sourceTokenSpans[offset]); // token_tag for ending whitespace
}
void HTML::annotateTagStack(Response const &response, std::vector<SpanIterator> const &targetTokenSpans,
std::vector<HTML::TagStack> &targetTokenTags) {
auto spanIt = targetTokenSpans.begin();
for (size_t sentenceIdx = 0; sentenceIdx < response.target.numSentences(); ++sentenceIdx) {
// Sentence prefix
targetTokenTags.push_back((*spanIt)->tags);
spanIt++;
// Offset in targetTokenTags at which this sentence's tags start.
size_t tagOffset = targetTokenTags.size();
// Initially, just copy the span's tags to this token
for (size_t t = 0; t < response.target.numWords(sentenceIdx); ++t) {
targetTokenTags.emplace_back((*spanIt)->tags);
spanIt++;
}
// If we have quality score information, add that as metadata as well.
if (!response.qualityScores.empty()) {
auto const &sentenceQuality = response.qualityScores[sentenceIdx];
// Create a single <font> tag for this sentence with sentence level info
Tag *sentenceTag = makeTag({Tag::ELEMENT, "font"});
sentenceTag->attributes += format(" x-bergamot-sentence-index=\"{}\" x-bergamot-sentence-score=\"{}\"",
sentenceIdx, sentenceQuality.sentenceScore);
// Add that tag to all tokens in this sentence.
for (size_t tokenIdx = 0; tokenIdx < response.target.numWords(sentenceIdx); ++tokenIdx) {
targetTokenTags[tagOffset + tokenIdx].push_back(sentenceTag);
}
// Add word level <font> tags as well to all tokens that make up a word.
for (size_t wordIdx = 0; wordIdx < sentenceQuality.wordRanges.size(); ++wordIdx) {
Tag *wordTag = makeTag({Tag::ELEMENT, "font"});
wordTag->attributes += format(" x-bergamot-word-index=\"{}\" x-bergamot-word-score=\"{}\"", wordIdx,
sentenceQuality.wordScores[wordIdx]);
auto const &range = sentenceQuality.wordRanges[wordIdx];
for (size_t tokenIdx = range.begin; tokenIdx < range.end; ++tokenIdx) {
targetTokenTags[tagOffset + tokenIdx].push_back(wordTag);
}
}
}
}
// Suffix
targetTokenTags.push_back((*spanIt)->tags);
spanIt++;
assert(spanIt == targetTokenSpans.end());
}
// Reports if token `str` is likely to be a continuation of a word. This is used
// to determine whether we should share the markup, or whether we should see
// this token as a fresh start. This implementation will treat "hello[world]"

View File

@ -162,7 +162,7 @@ class HTML {
void restore(Response &response);
private:
using SpanIterator = std::vector<HTML::Span>::const_iterator;
using SpanIterator = std::vector<HTML::Span>::iterator;
using AnnotatedText = marian::bergamot::AnnotatedText;
/// Reconstructs HTML in `response.source` (passed as `in`) and makes a list
@ -175,7 +175,8 @@ class HTML {
/// Inserts the HTML into `response.target` (passed as `in`) based on
/// `targetTokenSpans`, which points to a `Span` for each token (subword) in
/// `response.target`.
AnnotatedText restoreTarget(AnnotatedText const &in, std::vector<SpanIterator> const &targetTokenSpans);
AnnotatedText restoreTarget(AnnotatedText const &in, std::vector<SpanIterator> const &targetTokenSpans,
std::vector<HTML::TagStack> const &targetTokenTags);
/// Utilities to test whether subword `str` is part of a word together with
/// the subword `prev`, or a separate word. Basically *does `str` start with
@ -190,6 +191,9 @@ class HTML {
std::vector<HTML::SpanIterator> const &sourceTokenSpans,
std::vector<HTML::SpanIterator> &targetTokenSpans);
void annotateTagStack(Response const &response, std::vector<SpanIterator> const &targetTokenSpans,
std::vector<HTML::TagStack> &targetTokenTags);
/// Turns the alignment scores in `response.alignments` into one source token
/// per target token. Has some heuristics to keep all target tokens of a
/// single word pointing to the same span, and prefers spans with more markup

View File

@ -27,7 +27,7 @@ Response::SentenceQualityScore UnsupervisedQualityEstimator::computeSentenceScor
const float sentenceScore =
std::accumulate(std::begin(wordScores), std::end(wordScores), float(0.0)) / wordScores.size();
return {wordScores, subwordToWords(wordIndices, target, sentenceIdx), sentenceScore};
return {wordScores, wordIndices, sentenceScore};
}
LogisticRegressorQualityEstimator::Matrix::Matrix(const size_t rowsParam, const size_t colsParam)
@ -160,7 +160,7 @@ Response::SentenceQualityScore LogisticRegressorQualityEstimator::computeSentenc
const float sentenceScore =
std::accumulate(std::begin(wordScores), std::end(wordScores), float(0.0)) / wordScores.size();
return {wordScores, subwordToWords(wordIndices, target, sentenceIdx), sentenceScore};
return {wordScores, wordIndices, sentenceScore};
}
std::vector<float> LogisticRegressorQualityEstimator::predict(const Matrix& features) const {
@ -267,22 +267,4 @@ std::vector<SubwordRange> mapWords(const std::vector<float>& logProbs, const Ann
return wordIndices;
}
std::vector<ByteRange> subwordToWords(const std::vector<SubwordRange>& wordIndices, const AnnotatedText& target,
const size_t sentenceIdx) {
std::vector<ByteRange> words;
for (const SubwordRange& wordIndice : wordIndices) {
size_t wordBegin = target.wordAsByteRange(sentenceIdx, wordIndice.begin).begin;
size_t wordEnd = target.wordAsByteRange(sentenceIdx, wordIndice.end).begin;
if (isspace(target.text.at(wordBegin))) {
++wordBegin;
}
words.emplace_back(ByteRange{wordBegin, wordEnd});
}
return words;
}
} // namespace marian::bergamot

View File

@ -21,8 +21,6 @@ class QualityEstimator {
virtual void computeQualityScores(const Histories &histories, Response &response) const = 0;
};
using SubwordRange = ByteRange;
/// Unsupervised Quality Estimator model. It uses the translator model's log probabilities (log probs) as a proxy for
/// quality scores. Then, for a given word, its quality score is computed by taking the mean of the log probs of the
/// tokens that make it up. The sentence score is the mean of all word's log probs.
@ -209,14 +207,4 @@ inline std::shared_ptr<QualityEstimator> createQualityEstimator(const AlignedMem
std::vector<SubwordRange> mapWords(const std::vector<float> &logProbs, const AnnotatedText &target,
const size_t sentenceIdx);
/// Given a vector of subwordRanges, it maps the elements to be real words rather than sublevel tokens. The words are
/// represented through ByteRanges.
/// @param [in] wordIndices: A vector where each element correspond to the index of a real word and its values are
/// represented by the SubwordRanges (which are aliases of ByteRanges) which represents sublevel token positions
/// @param [in] target: AnnotatedText target value
/// @param [in] sentenceIdx: the id of a candidate sentence
std::vector<ByteRange> subwordToWords(const std::vector<SubwordRange> &wordIndices, const AnnotatedText &target,
const size_t sentenceIdx);
} // namespace marian::bergamot

View File

@ -142,4 +142,22 @@ std::vector<Alignment> remapAlignments(const Response &first, const Response &se
return alignments;
}
std::vector<ByteRange> getWordByteRanges(const Response &response, size_t sentenceIdx) {
std::vector<ByteRange> wordByteRanges;
wordByteRanges.reserve(response.qualityScores[sentenceIdx].wordRanges.size());
for (auto &&word : response.qualityScores[sentenceIdx].wordRanges) {
size_t wordBegin = response.target.wordAsByteRange(sentenceIdx, word.begin).begin;
size_t wordEnd = response.target.wordAsByteRange(sentenceIdx, word.end).begin;
if (std::isspace(response.target.text.at(wordBegin))) {
++wordBegin;
}
wordByteRanges.emplace_back(ByteRange{wordBegin, wordEnd});
}
return wordByteRanges;
}
} // namespace marian::bergamot

View File

@ -30,8 +30,8 @@ struct Response {
struct SentenceQualityScore {
/// Quality score of each translated word
std::vector<float> wordScores;
/// Each word position in the translated text
std::vector<ByteRange> wordByteRanges;
/// Position of start and end token of each word in the translated text
std::vector<SubwordRange> wordRanges;
/// Whole sentence quality score (it is composed by the mean of its words)
float sentenceScore = 0.0;
};
@ -77,6 +77,8 @@ struct Response {
std::vector<Alignment> remapAlignments(const Response &first, const Response &second);
std::vector<ByteRange> getWordByteRanges(Response const &response, size_t sentenceIdx);
} // namespace bergamot
} // namespace marian

View File

@ -10,7 +10,6 @@
#include "response.h"
using Response = marian::bergamot::Response;
using SentenceQualityScore = marian::bergamot::Response::SentenceQualityScore;
using ByteRange = marian::bergamot::ByteRange;
using namespace emscripten;
@ -20,25 +19,14 @@ EMSCRIPTEN_BINDINGS(byte_range) {
value_object<ByteRange>("ByteRange").field("begin", &ByteRange::begin).field("end", &ByteRange::end);
}
std::vector<SentenceQualityScore> getQualityScores(const Response& response) { return response.qualityScores; }
EMSCRIPTEN_BINDINGS(response) {
class_<Response>("Response")
.constructor<>()
.function("size", &Response::size)
.function("getQualityScores", &getQualityScores)
.function("getOriginalText", &Response::getOriginalText)
.function("getTranslatedText", &Response::getTranslatedText)
.function("getSourceSentence", &Response::getSourceSentenceAsByteRange)
.function("getTranslatedSentence", &Response::getTargetSentenceAsByteRange);
value_object<SentenceQualityScore>("SentenceQualityScore")
.field("wordScores", &SentenceQualityScore::wordScores)
.field("wordByteRanges", &SentenceQualityScore::wordByteRanges)
.field("sentenceScore", &SentenceQualityScore::sentenceScore);
register_vector<Response>("VectorResponse");
register_vector<SentenceQualityScore>("VectorSentenceQualityScore");
register_vector<float>("VectorFloat");
register_vector<ByteRange>("VectorByteRange");
}

View File

@ -73,7 +73,7 @@ label {
align-self: center;
}
textarea {
textarea, .output-area {
padding: 1rem;
font-family: sans-serif;
font-size: 1rem;
@ -97,3 +97,22 @@ button:hover {
#output {
background-color: #f4f4f4;
}
.output-area [x-bergamot-word-score].bad {
background-image:
linear-gradient(45deg, transparent 65%, red 80%, transparent 90%),
linear-gradient(135deg, transparent 5%, red 15%, transparent 25%),
linear-gradient(135deg, transparent 45%, red 55%, transparent 65%),
linear-gradient(45deg, transparent 25%, red 35%, transparent 50%);
background-repeat:repeat-x;
background-size: 8px 2px;
background-position:0 95%;
}
.output-area [x-bergamot-sentence-score].bad {
background: rgba(255, 128, 128, 0.8);
}
.output-area [x-bergamot-sentence-index].highlight-sentence {
background: rgba(255, 255, 128, 0.8);
}

View File

@ -24,7 +24,7 @@
To
<select id="lang-to" name="to" class="lang-select"></select>
</label>
<textarea id="output" name="output" readonly></textarea>
<div id="output" class="output-area"></div>
</div>
<div class="footer" id="status"></div>
</div>

View File

@ -38,22 +38,58 @@ const _prepareTranslateOptions = (paragraphs) => {
return translateOptions;
};
const textToHTML = (text) => {
const div = document.createElement('div');
div.appendChild(document.createTextNode(text));
return div.innerHTML;
};
const translateCall = () => {
const text = document.querySelector("#input").value + " ";
const text = document.querySelector("#input").value;
if (!text.trim().length) return;
const paragraphs = text.split("\n");
const paragraphs = text.split(/\n+/).map(textToHTML); // escape HTML
const translateOptions = _prepareTranslateOptions(paragraphs);
$("#output").setAttribute("disabled", true);
const lngFrom = langFrom.value;
const lngTo = langTo.value;
worker.postMessage(["translate", lngFrom, lngTo, paragraphs, translateOptions]);
};
const addQualityClasses = (root) => {
// You can do this wit CSS variables, calc() and min/max, but JS is just easier
root.querySelectorAll('[x-bergamot-sentence-score]').forEach(el => {
// Note: these thresholds are just examples, they are not good thresholds!
el.classList.toggle('bad', parseFloat(el.getAttribute('x-bergamot-sentence-score')) > -0.1);
});
root.querySelectorAll('[x-bergamot-word-score]').forEach(el => {
// Note: these thresholds are just examples, they are not good thresholds!
el.classList.toggle('bad', parseFloat(el.getAttribute('x-bergamot-word-score')) > -0.1);
});
// Add tooltips to each (sub)word with sentence and word score.
root.querySelectorAll('[x-bergamot-sentence-score] > [x-bergamot-word-score]').forEach(el => {
const sentenceScore = parseFloat(el.parentNode.getAttribute('x-bergamot-sentence-score'));
const wordScore = parseFloat(el.getAttribute('x-bergamot-word-score'));
el.title = `Sentence: ${sentenceScore} Word: ${wordScore}`;
});
}
worker.onmessage = function (e) {
if (e.data[0] === "translate_reply" && e.data[1]) {
document.querySelector("#output").value = e.data[1].join("\n\n");
$("#output").removeAttribute("disabled");
// Clear output of previous translation
document.querySelector("#output").innerHTML = '';
// Add each translation in its own div to have a known root in which the
// sentence ids are unique. Used for highlighting sentences.
e.data[1].forEach(translatedHTML => {
const translation = document.createElement('div');
translation.classList.add('translation');
translation.innerHTML = translatedHTML;
addQualityClasses(translation);
document.querySelector("#output").appendChild(translation);
});
} else if (e.data[0] === "load_model_reply" && e.data[1]) {
status(e.data[1]);
translateCall();
@ -76,8 +112,8 @@ const loadModel = () => {
console.log(`Loading model '${lngFrom}${lngTo}'`);
worker.postMessage(["load_model", lngFrom, lngTo]);
} else {
const input = document.querySelector("#input").value;
document.querySelector("#output").value = input;
const input = textToHTML(document.querySelector("#input").value);
document.querySelector("#output").innerHTML = input;
}
};
@ -95,6 +131,14 @@ $(".swap").addEventListener("click", e => {
loadModel();
});
$('#output').addEventListener('mouseover', e => {
const root = e.target.closest('.translation');
const sentence = e.target.parentNode.hasAttribute('x-bergamot-sentence-index') ? e.target.parentNode.getAttribute('x-bergamot-sentence-index') : null;
document.querySelectorAll('#output font[x-bergamot-sentence-index]').forEach(el => {
el.classList.toggle('highlight-sentence', el.getAttribute('x-bergamot-sentence-index') === sentence && el.closest('.translation') === root);
})
})
function init() {
// try to guess input language from user agent
let myLang = navigator.language;

View File

@ -137,13 +137,11 @@ const translate = (from, to, input, translateOptions) => {
const listSourceText = _parseSourceText(vectorResponse);
const listTranslatedTextSentences = _parseTranslatedTextSentences(vectorResponse);
const listSourceTextSentences = _parseSourceTextSentences(vectorResponse);
const listTranslatedTextSentenceQualityScores = _parseTranslatedTextSentenceQualityScores(vectorResponse);
log(`Source text: ${listSourceText}`);
log(`Translated text: ${listTranslatedText}`);
log(`Translated sentences: ${JSON.stringify(listTranslatedTextSentences)}`);
log(`Source sentences: ${JSON.stringify(listSourceTextSentences)}`);
log(`Translated sentence quality scores: ${JSON.stringify(listTranslatedTextSentenceQualityScores)}`);
return listTranslatedText;
} finally {
@ -292,44 +290,6 @@ const _parseSourceTextSentences = (vectorResponse) => {
return result;
}
const _parseTranslatedTextSentenceQualityScores = (vectorResponse) => {
const result = [];
for (let i = 0; i < vectorResponse.size(); i++) {
const response = vectorResponse.get(i);
const translatedText = response.getTranslatedText();
const vectorSentenceQualityScore = response.getQualityScores();
log(`No. of sentences: "${vectorSentenceQualityScore.size()}"`);
const sentenceQualityScores = [];
for (let sentenceIndex=0; sentenceIndex < vectorSentenceQualityScore.size(); sentenceIndex++) {
const sentenceQualityScoreObject = vectorSentenceQualityScore.get(sentenceIndex);
const wordByteRangeList = [];
const wordList = [];
const wordScoreList = [];
const vectorWordScore = sentenceQualityScoreObject.wordScores;
const vectorWordByteRange = sentenceQualityScoreObject.wordByteRanges;
for (let wordIndex = 0; wordIndex < vectorWordScore.size(); wordIndex++) {
const wordScore = vectorWordScore.get(wordIndex);
const wordByteRange = vectorWordByteRange.get(wordIndex);
wordScoreList.push(wordScore);
wordByteRangeList.push(wordByteRange);
const word = _getSubString(translatedText, wordByteRange);
wordList.push(word);
}
const sentenceQualityScore = {
wordByteRanges: wordByteRangeList,
words: wordList,
wordScores: wordScoreList,
sentenceScore: sentenceQualityScoreObject.sentenceScore
};
sentenceQualityScores.push(sentenceQualityScore);
}
result.push(sentenceQualityScores);
}
return result;
}
const _prepareResponseOptions = (translateOptions) => {
let vectorResponseOptions = new Module.VectorResponseOptions;
translateOptions.forEach(translateOption => {