Merged PR 18185: Support for Microsoft legacy binary shortlist

Adds support for Microsoft-internal binary shortlist format.
This commit is contained in:
Martin Junczys-Dowmunt 2021-03-18 03:33:13 +00:00
parent a1aaa32c6a
commit e08c52a8df
13 changed files with 1429 additions and 4 deletions

View File

@ -40,6 +40,7 @@ set(MARIAN_SOURCES
data/corpus_sqlite.cpp
data/corpus_nbest.cpp
data/text_input.cpp
data/shortlist.cpp
3rd_party/cnpy/cnpy.cpp
3rd_party/ExceptionWithCallStack.cpp
@ -107,10 +108,15 @@ set(MARIAN_SOURCES
training/validator.cpp
training/communicator.cpp
# this is only compiled to catch build errors, but not linked
# this is only compiled to catch build errors
microsoft/quicksand.cpp
microsoft/cosmos.cpp
# copied from quicksand to be able to read binary shortlist
microsoft/shortlist/utils/Converter.cpp
microsoft/shortlist/utils/StringUtils.cpp
microsoft/shortlist/utils/ParameterTree.cpp
$<TARGET_OBJECTS:libyaml-cpp>
$<TARGET_OBJECTS:SQLiteCpp>
$<TARGET_OBJECTS:pathie-cpp>

View File

@ -546,7 +546,6 @@ void FactoredVocab::constructNormalizationInfoForVocab() {
/*virtual*/ void FactoredVocab::transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const {
for (; num-- > 0; ptr++) {
auto word = Word::fromWordIndex(*ptr);
auto wordString = word2string(word);
auto lemmaIndex = getFactor(word, 0) + groupRanges_[0].first;
*ptr = (WordIndex)lemmaIndex;
}

153
src/data/shortlist.cpp Normal file
View File

@ -0,0 +1,153 @@
#include "data/shortlist.h"
#include "microsoft/shortlist/utils/ParameterTree.h"
namespace marian {
namespace data {
// Interprets `current` as the start of an array of T, returns a pointer to
// that start, and advances `current` past `num` elements of T.
template <typename T>
const T* get(const void*& current, size_t num = 1) {
  const T* const head = static_cast<const T*>(current);
  current = head + num;
  return head;
}
// Memory-maps a Quicksand binary shortlist and records pointers to its sections.
//
// The path is the first value of the "shortlist" option; any further values
// are parsed but ignored (a warning is logged if they are non-zero).
//
// The file is read sequentially in this order:
//   int32 magic number (must be 1234567890)
//   serialized ParameterTree (must contain "use_16_bit")
//   int32 numDefaultIds, followed by numDefaultIds int32 default ids
//   int32 numSourceIds, followed by numSourceIds int32 lengths and numSourceIds int32 offsets
//   int32 numShortlistIds, followed by numShortlistIds ids of idSize_ bytes each
// All stored pointers point into mmap_, so they are valid only while mmap_ is mapped.
QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
                                                         Ptr<const Vocab> srcVocab,
                                                         Ptr<const Vocab> trgVocab,
                                                         size_t srcIdx,
                                                         size_t /*trgIdx*/,
                                                         bool /*shared*/)
    : options_(options),
      srcVocab_(srcVocab),
      trgVocab_(trgVocab),
      srcIdx_(srcIdx) {
  std::vector<std::string> vals = options_->get<std::vector<std::string>>("shortlist");
  ABORT_IF(vals.empty(), "No path to shortlist given");
  std::string fname = vals[0];

  // The extra values are meaningful for the text shortlist only; parse them so
  // that we can tell the user they have no effect here.
  auto firstNum = vals.size() > 1 ? std::stoi(vals[1]) : 0;
  auto bestNum = vals.size() > 2 ? std::stoi(vals[2]) : 0;
  float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0;
  if(firstNum != 0 || bestNum != 0 || threshold != 0) {
    LOG(warn, "You have provided additional parameters for the Quicksand shortlist, but they are ignored.");
  }

  mmap_ = mio::mmap_source(fname);     // memory-map the binary file once
  const void* current = mmap_.data();  // pointer iterator over binary file

  // compare magic number in binary file to make sure we are reading the right thing
  const int32_t MAGIC_NUMBER = 1234567890;
  int32_t header_magic_number = *get<int32_t>(current);
  ABORT_IF(header_magic_number != MAGIC_NUMBER, "Trying to mmap Quicksand shortlist but encountered wrong magic number");

  auto config = ::quicksand::ParameterTree::FromBinaryReader(current);
  use16bit_ = config->GetBoolReq("use_16_bit");

  LOG(info, "[data] Mapping Quicksand shortlist from {}", fname);

  // width in bytes of one entry of the source-to-shortlist table
  idSize_ = static_cast<int32_t>(use16bit_ ? sizeof(uint16_t) : sizeof(int32_t));

  // mmap the binary shortlist pieces; a negative count would turn into a huge
  // unsigned element count when advancing the read pointer, so reject it early
  numDefaultIds_ = *get<int32_t>(current);
  ABORT_IF(numDefaultIds_ < 0, "Quicksand shortlist is corrupt: negative number of default ids");
  defaultIds_ = get<int32_t>(current, numDefaultIds_);

  numSourceIds_ = *get<int32_t>(current);
  ABORT_IF(numSourceIds_ < 0, "Quicksand shortlist is corrupt: negative number of source ids");
  sourceLengths_ = get<int32_t>(current, numSourceIds_);
  sourceOffsets_ = get<int32_t>(current, numSourceIds_);

  numShortlistIds_ = *get<int32_t>(current);
  ABORT_IF(numShortlistIds_ < 0, "Quicksand shortlist is corrupt: negative number of shortlist ids");
  // multiply in size_t so that large tables cannot overflow int32 arithmetic
  sourceToShortlistIds_ = get<uint8_t>(current, static_cast<size_t>(idSize_) * static_cast<size_t>(numShortlistIds_));

  // display parameters
  LOG(info,
      "[data] Quicksand shortlist has {} source ids, {} default ids and {} shortlist ids",
      numSourceIds_,
      numDefaultIds_,
      numShortlistIds_);
}
// Builds the shortlist for one batch.
//
// The result contains the default ids plus, for every source token of the
// batch's srcIdx_-th stream that has an entry in the table, its mapped target
// ids. Ids are collected until the set reaches the target vocabulary size.
// Returns a Shortlist over the collected ids in ascending order.
Ptr<Shortlist> QuicksandShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
  auto srcBatch = (*batch)[srcIdx_];
  const size_t maxShortlistSize = trgVocab_->size();

  std::unordered_set<int32_t> indexSet;
  for(int32_t i = 0; i < numDefaultIds_ && static_cast<size_t>(i) < maxShortlistSize; ++i) {
    indexSet.insert(defaultIds_[i]);
  }

  // One (table start, entry count) pair per source token that has a mapping.
  // The list grows with the batch instead of being pre-sized to the vocabulary
  // size, so a batch with more tokens than target words cannot write past the end.
  const auto& srcWords = srcBatch->data();
  std::vector<std::pair<const uint8_t*, int32_t>> curShortlists;
  curShortlists.reserve(srcWords.size());

  // Because we might fill up our shortlist before reaching max_shortlist_size, we fill the shortlist in order of rank.
  // E.g., first rank of word 0, first rank of word 1, ... second rank of word 0, ...
  int32_t maxLength = 0;
  for(Word word : srcWords) {
    int32_t sourceId = (int32_t)word.toWordIndex();
    srcVocab_->transcodeToShortlistInPlace((WordIndex*)&sourceId, 1);

    // only ids inside [0, numSourceIds_) index the length/offset arrays
    if(sourceId >= 0 && sourceId < numSourceIds_) {
      const uint8_t* curShortlistIds
          = sourceToShortlistIds_ + idSize_ * sourceOffsets_[sourceId];  // start position for mapping
      const int32_t length = sourceLengths_[sourceId];                  // how many mappings are there
      curShortlists.emplace_back(curShortlistIds, length);
      if(length > maxLength)
        maxLength = length;
    }
  }

  // collect the actual shortlist mappings, rank by rank
  for(int32_t i = 0; i < maxLength && indexSet.size() < maxShortlistSize; i++) {
    for(size_t j = 0; j < curShortlists.size() && indexSet.size() < maxShortlistSize; j++) {
      const int32_t length = curShortlists[j].second;
      if(i < length) {
        const uint8_t* source_shortlist_ids_bytes = curShortlists[j].first;
        int32_t id = 0;
        if(use16bit_) {
          const uint16_t* source_shortlist_ids = reinterpret_cast<const uint16_t*>(source_shortlist_ids_bytes);
          id = (int32_t)source_shortlist_ids[i];
        } else {
          const int32_t* source_shortlist_ids = reinterpret_cast<const int32_t*>(source_shortlist_ids_bytes);
          id = source_shortlist_ids[i];
        }
        indexSet.insert(id);
      }
    }
  }

  // turn into vector and sort (selected indices)
  std::vector<WordIndex> indices;
  indices.reserve(indexSet.size());
  for(auto i : indexSet)
    indices.push_back((WordIndex)i);
  std::sort(indices.begin(), indices.end());

  return New<Shortlist>(indices);
}
// Factory for shortlist generators: a path (first value of the "shortlist"
// option) with the extension ".bin" selects the Quicksand binary generator,
// any other path selects the lexical generator.
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
                                                 Ptr<const Vocab> srcVocab,
                                                 Ptr<const Vocab> trgVocab,
                                                 size_t srcIdx,
                                                 size_t trgIdx,
                                                 bool shared) {
  const std::vector<std::string> shortlistArgs = options->get<std::vector<std::string>>("shortlist");
  ABORT_IF(shortlistArgs.empty(), "No path to shortlist given");

  const std::string path = shortlistArgs.front();
  const bool isBinary = filesystem::Path(path).extension().string() == ".bin";
  if(!isBinary) {
    return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
  }
  return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
}
} // namespace data
} // namespace marian

View File

@ -5,6 +5,7 @@
#include "common/file_stream.h"
#include "data/corpus_base.h"
#include "data/types.h"
#include "mio/mio.hpp"
#include <random>
#include <unordered_map>
@ -292,5 +293,51 @@ public:
}
};
/*
Legacy binary shortlist for Microsoft-internal use.

The generator memory-maps the binary file; the pointer members below point
into that mapping and must not outlive mmap_.
*/
class QuicksandShortlistGenerator : public ShortlistGenerator {
private:
  Ptr<Options> options_;
  Ptr<const Vocab> srcVocab_;
  Ptr<const Vocab> trgVocab_;

  size_t srcIdx_;  // index of the batch stream whose tokens are looked up

  mio::mmap_source mmap_;  // mapping of the binary shortlist file

  // all the quicksand bits go here
  bool use16bit_{false};                         // table entries are uint16_t instead of int32_t
  int32_t numDefaultIds_{0};                     // number of entries in defaultIds_
  int32_t idSize_{0};                            // size in bytes of one table entry
  const int32_t* defaultIds_{nullptr};           // ids that are always part of the shortlist
  int32_t numSourceIds_{0};                      // number of entries in sourceLengths_/sourceOffsets_
  const int32_t* sourceLengths_{nullptr};        // per source id: number of mapped ids
  const int32_t* sourceOffsets_{nullptr};        // per source id: index of its first entry in the table
  int32_t numShortlistIds_{0};                   // total number of entries in the table
  const uint8_t* sourceToShortlistIds_{nullptr}; // raw bytes of the source-to-shortlist table

public:
  QuicksandShortlistGenerator(Ptr<Options> options,
                              Ptr<const Vocab> srcVocab,
                              Ptr<const Vocab> trgVocab,
                              size_t srcIdx = 0,
                              size_t trgIdx = 1,
                              bool shared = false);

  virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
};
/*
Shortlist factory to create the correct type of shortlist generator.
Currently assumes everything is a text shortlist unless the extension of the
shortlist path is *.bin, in which case the Microsoft legacy binary shortlist
(QuicksandShortlistGenerator) is used.
All arguments are forwarded unchanged to the selected generator's constructor.
*/
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,
size_t srcIdx = 0,
size_t trgIdx = 1,
bool shared = false);
} // namespace data
} // namespace marian

View File

@ -0,0 +1,25 @@
#pragma once
// Do NOT include this file directly except in special circumstances.
// (E.g., you want to define macros which call these but don't want to include Logger.h everywhere).
// Normally you should include logging/Logger.h
// Stub versions of the logging macros. Every macro below discards its
// arguments (nothing is formatted or printed) and terminates the process by
// calling abort(). The do { ... } while (0) wrapper makes each macro usable as
// a single statement.

// Stub: ignores the format string and arguments and calls abort().
#define LOG_WRITE(format, ...) do {\
abort(); \
} while (0)
// Stub: ignores the string and calls abort().
#define LOG_WRITE_STRING(str) do {\
abort(); \
} while (0)
// Stub: ignores the format string and arguments and calls abort().
#define LOG_ERROR(format, ...) do {\
abort(); \
} while (0)
// Stub: ignores the format string and arguments and calls abort(); despite the
// name, nothing is thrown.
#define LOG_ERROR_AND_THROW(format, ...) do {\
abort(); \
} while (0)
// Stub: ignores the format string and arguments and calls abort().
#define DECODING_LOGIC_ERROR(format, ...) do {\
abort(); \
} while (0)

View File

@ -0,0 +1,59 @@
#include "microsoft/shortlist/utils/Converter.h"
namespace quicksand {
#include "microsoft/shortlist/logging/LoggerMacros.h"
// Parses `str` as an int64_t via ConvertSingleInternal; "int64_t" is the type
// name handed to the conversion-error handler.
int64_t Converter::ToInt64(const std::string& str) {
return ConvertSingleInternal<int64_t>(str, "int64_t");
}
// Parses `str` as a uint64_t via ConvertSingleInternal; "uint64_t" is the type
// name handed to the conversion-error handler.
uint64_t Converter::ToUInt64(const std::string& str) {
  return ConvertSingleInternal<uint64_t>(str, "uint64_t");
}
// Parses `str` as an int32_t via ConvertSingleInternal; "int32_t" is the type
// name handed to the conversion-error handler.
int32_t Converter::ToInt32(const std::string& str) {
return ConvertSingleInternal<int32_t>(str, "int32_t");
}
// Parses `str` as a floating-point number and returns it as a float.
float Converter::ToFloat(const std::string& str) {
// In case the value is out of range of a 32-bit float, but in range of a 64-bit double,
// it's better to convert as a double and then do the conversion.
return (float)ConvertSingleInternal<double>(str, "float");
}
// Parses `str` as a double via ConvertSingleInternal; "double" is the type
// name handed to the conversion-error handler.
double Converter::ToDouble(const std::string& str) {
return ConvertSingleInternal<double>(str, "double");
}
// Converts one of the boolean spellings accepted by TryConvert into a bool.
// Any other string is reported through LOG_ERROR_AND_THROW.
bool Converter::ToBool(const std::string& str) {
  bool parsed = false;
  const bool recognized = TryConvert(str, /* out */ parsed);
  if (!recognized) {
    LOG_ERROR_AND_THROW("The string '%s' is not interpretable as the type 'bool'", str.c_str());
  }
  return parsed;
}
std::vector<int32_t> Converter::ToInt32Vector(const std::vector<std::string>& items) {
return ConvertVectorInternal<int32_t, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "int32_t");
}
std::vector<int64_t> Converter::ToInt64Vector(const std::vector<std::string>& items) {
return ConvertVectorInternal<int64_t, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "int64_t");
}
std::vector<float> Converter::ToFloatVector(const std::vector<std::string>& items) {
return ConvertVectorInternal<float, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "float");
}
std::vector<double> Converter::ToDoubleVector(const std::vector<std::string>& items) {
return ConvertVectorInternal<double, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "double");
}
// Reports that `str` could not be converted to the type called `type_name`.
void Converter::HandleConversionError(const std::string& str, const char * type_name) {
  // The logging macro may discard its arguments; mark the parameters as used
  // so that builds with such a macro do not warn about unused parameters.
  (void)str;
  (void)type_name;
  LOG_ERROR_AND_THROW("The string '%s' is not interpretable as the type '%s'", str.c_str(), type_name);
}
} // namespace quicksand

View File

@ -0,0 +1,83 @@
#pragma once
#include <stdint.h>
#include <string>
#include <vector>
#include <sstream>
namespace quicksand {
// Static helpers that convert strings to numeric and boolean values.
class Converter {
public:
  // Single-value conversions; conversion failures go to HandleConversionError
  // (ToBool reports failures itself).
  static int32_t ToInt32(const std::string& str);
  static int64_t ToInt64(const std::string& str);
  static uint64_t ToUInt64(const std::string& str);
  static float ToFloat(const std::string& str);
  static double ToDouble(const std::string& str);
  static bool ToBool(const std::string& str);

  // Element-wise conversions of a list of strings.
  static std::vector<int32_t> ToInt32Vector(const std::vector<std::string>& items);
  static std::vector<int64_t> ToInt64Vector(const std::vector<std::string>& items);
  static std::vector<float> ToFloatVector(const std::vector<std::string>& items);
  static std::vector<double> ToDoubleVector(const std::vector<std::string>& items);

  // Recognizes the boolean spellings listed below (exact match, no trimming).
  // On success stores the value in `obj` and returns true; otherwise returns
  // false and leaves `obj` untouched.
  static bool TryConvert(const std::string& str, /* out*/ bool& obj) {
    const bool means_true = str == "True" || str == "true" || str == "TRUE"
                            || str == "Yes" || str == "yes" || str == "1";
    const bool means_false = str == "False" || str == "false" || str == "FALSE"
                             || str == "No" || str == "no" || str == "0";
    if (!means_true && !means_false) {
      return false;
    }
    obj = means_true;
    return true;
  }

  // Extracts a T from `str` with operator>>. `value` is reset to T() first;
  // returns whether the extraction succeeded.
  template <typename T>
  static bool TryConvert(const std::string& str, /* out*/ T& value) {
    value = T();
    std::istringstream reader(str);
    reader >> value;
    return !reader.fail();
  }

private:
  template <typename T>
  static T ConvertSingleInternal(const std::string& str, const char * type_name);

  template <typename T, typename I>
  static std::vector<T> ConvertVectorInternal(I begin, I end, const char * type_name);

  static void HandleConversionError(const std::string& str, const char * type_name);
};
// Extracts a T from `str` with operator>>. If the extraction fails,
// HandleConversionError is called with `str` and `type_name`.
template <typename T>
T Converter::ConvertSingleInternal(const std::string& str, const char * type_name) {
  T parsed = T();
  std::istringstream reader(str);
  reader >> parsed;
  if (reader.fail()) {
    HandleConversionError(str, type_name);
  }
  return parsed;
}
// Converts every string in [begin, end) with ConvertSingleInternal<T> and
// returns the results in the same order.
template <typename T, typename I>
std::vector<T> Converter::ConvertVectorInternal(I begin, I end, const char * type_name) {
  std::vector<T> converted;
  I cursor = begin;
  while (cursor != end) {
    converted.push_back(ConvertSingleInternal<T>(*cursor, type_name));
    cursor++;
  }
  return converted;
}
} // namespace quicksand

View File

@ -0,0 +1,417 @@
#include "microsoft/shortlist/utils/ParameterTree.h"
#include <string>
#include "microsoft/shortlist/utils/StringUtils.h"
#include "microsoft/shortlist/utils/Converter.h"
namespace quicksand {
#include "microsoft/shortlist/logging/LoggerMacros.h"
// Shared tree named "params".
std::shared_ptr<ParameterTree> ParameterTree::m_empty_tree = std::make_shared<ParameterTree>("params");

// Creates a node named "root" with empty text and no children.
ParameterTree::ParameterTree()
    : m_name("root") {
}

// Creates a node called `name` with empty text and no children.
ParameterTree::ParameterTree(const std::string& name)
    : m_name(name) {
}

ParameterTree::~ParameterTree() {
}
// Currently a no-op: the body is empty, so name, text, children and registered
// parameters are all left as they are.
// NOTE(review): the name suggests the node should be reset — confirm whether
// the empty body is intentional.
void ParameterTree::Clear() {
}
// Public entry point for variable substitution; forwards both arguments
// unchanged to ReplaceVariablesInternal.
void ParameterTree::ReplaceVariables(
const std::unordered_map<std::string, std::string>& vars,
bool error_on_unknown_vars)
{
ReplaceVariablesInternal(vars, error_on_unknown_vars);
}
// Each Register* function records `param` under `name` together with the
// matching ParameterType tag. The pointer is stored as-is, not dereferenced.

void ParameterTree::RegisterInt32(const std::string& name, int32_t * param) {
  void * const target = static_cast<void *>(param);
  RegisterItemInternal(name, PARAM_TYPE_INT32, target);
}

void ParameterTree::RegisterInt64(const std::string& name, int64_t * param) {
  void * const target = static_cast<void *>(param);
  RegisterItemInternal(name, PARAM_TYPE_INT64, target);
}

void ParameterTree::RegisterFloat(const std::string& name, float * param) {
  void * const target = static_cast<void *>(param);
  RegisterItemInternal(name, PARAM_TYPE_FLOAT, target);
}

void ParameterTree::RegisterDouble(const std::string& name, double * param) {
  void * const target = static_cast<void *>(param);
  RegisterItemInternal(name, PARAM_TYPE_DOUBLE, target);
}

void ParameterTree::RegisterBool(const std::string& name, bool * param) {
  void * const target = static_cast<void *>(param);
  RegisterItemInternal(name, PARAM_TYPE_BOOL, target);
}

void ParameterTree::RegisterString(const std::string& name, std::string * param) {
  void * const target = static_cast<void *>(param);
  RegisterItemInternal(name, PARAM_TYPE_STRING, target);
}
std::shared_ptr<ParameterTree> ParameterTree::FromBinaryReader(const void*& current) {
std::shared_ptr<ParameterTree> root = std::make_shared<ParameterTree>();
root->ReadBinary(current);
return root;
}
void ParameterTree::SetRegisteredParams() {
for (std::size_t i = 0; i < m_registered_params.size(); i++) {
const RegisteredParam& rp = m_registered_params[i];
switch (rp.Type()) {
case PARAM_TYPE_INT32:
(*(int32_t *)rp.Data()) = GetInt32Req(rp.Name());
break;
case PARAM_TYPE_INT64:
(*(int64_t *)rp.Data()) = GetInt64Req(rp.Name());
break;
default:
LOG_ERROR_AND_THROW("Unknown ParameterType: %d", (int)rp.Type());
}
}
}
// The Get*Or accessors look up the text of the direct child called `name`.
// If there is no such child they return `defaultValue`; otherwise the text is
// converted to the requested type with the matching Converter function.

int32_t ParameterTree::GetInt32Or(const std::string& name, int32_t defaultValue) const {
  const std::string * const text = GetParamInternal(name);
  return text != nullptr ? Converter::ToInt32(*text) : defaultValue;
}

int64_t ParameterTree::GetInt64Or(const std::string& name, int64_t defaultValue) const {
  const std::string * const text = GetParamInternal(name);
  return text != nullptr ? Converter::ToInt64(*text) : defaultValue;
}

uint64_t ParameterTree::GetUInt64Or(const std::string& name, uint64_t defaultValue) const {
  const std::string * const text = GetParamInternal(name);
  return text != nullptr ? Converter::ToUInt64(*text) : defaultValue;
}

double ParameterTree::GetDoubleOr(const std::string& name, double defaultValue) const {
  const std::string * const text = GetParamInternal(name);
  return text != nullptr ? Converter::ToDouble(*text) : defaultValue;
}

float ParameterTree::GetFloatOr(const std::string& name, float defaultValue) const {
  const std::string * const text = GetParamInternal(name);
  return text != nullptr ? Converter::ToFloat(*text) : defaultValue;
}

std::string ParameterTree::GetStringOr(const std::string& name, const std::string& defaultValue) const {
  const std::string * const text = GetParamInternal(name);
  return text != nullptr ? *text : defaultValue;
}

bool ParameterTree::GetBoolOr(const std::string& name, bool defaultValue) const {
  const std::string * const text = GetParamInternal(name);
  return text != nullptr ? Converter::ToBool(*text) : defaultValue;
}
// The typed Get*Req accessors fetch the parameter text with GetStringReq
// (which reports a missing parameter) and convert it with the matching
// Converter function.

int32_t ParameterTree::GetInt32Req(const std::string& name) const {
  return Converter::ToInt32(GetStringReq(name));
}

uint64_t ParameterTree::GetUInt64Req(const std::string& name) const {
  return Converter::ToUInt64(GetStringReq(name));
}

int64_t ParameterTree::GetInt64Req(const std::string& name) const {
  return Converter::ToInt64(GetStringReq(name));
}

double ParameterTree::GetDoubleReq(const std::string& name) const {
  return Converter::ToDouble(GetStringReq(name));
}

float ParameterTree::GetFloatReq(const std::string& name) const {
  return Converter::ToFloat(GetStringReq(name));
}

bool ParameterTree::GetBoolReq(const std::string& name) const {
  return Converter::ToBool(GetStringReq(name));
}
// Returns the text of the direct child called `name`. A missing child is
// reported through LOG_ERROR_AND_THROW together with a dump of this tree.
std::string ParameterTree::GetStringReq(const std::string& name) const {
  const std::string * const text = GetParamInternal(name);
  if (!text) {
    LOG_ERROR_AND_THROW("Required parameter <%s> not found in ParameterTree:\n%s", name.c_str(), ToString().c_str());
  }
  return *text;
}
// Same as GetFileListOptional, but an empty result is reported through
// LOG_ERROR_AND_THROW.
std::vector<std::string> ParameterTree::GetFileListReq(const std::string& name) const {
  std::vector<std::string> files = GetFileListOptional(name);
  if (files.empty()) {
    LOG_ERROR_AND_THROW("No files were found for parameter: %s", name.c_str());
  }
  return files;
}

// Splits the text of parameter `name` at ";" and returns the pieces.
// Returns an empty list if the parameter is missing or its text is empty.
std::vector<std::string> ParameterTree::GetFileListOptional(const std::string& name) const {
  const std::string * const text = GetParamInternal(name);
  const bool has_text = text != nullptr && !text->empty();
  if (!has_text) {
    return std::vector<std::string>();
  }
  return StringUtils::Split(*text, ";");
}
// Splits the text of the required parameter `name` at every occurrence of `sep`.
std::vector<std::string> ParameterTree::GetStringListReq(const std::string& name, const std::string& sep) const {
  const std::string text = GetStringReq(name);
  return StringUtils::Split(text, sep);
}

// Like GetStringListReq, but a missing parameter is treated as empty text.
std::vector<std::string> ParameterTree::GetStringListOptional(const std::string& name, const std::string& sep) const {
  const std::string text = GetStringOr(name, "");
  return StringUtils::Split(text, sep);
}
// Returns the first direct child called `name`. A missing child is reported
// through LOG_ERROR_AND_THROW.
std::shared_ptr<ParameterTree> ParameterTree::GetChildReq(const std::string& name) const {
  for (std::size_t i = 0; i < m_children.size(); ++i) {
    if (m_children[i]->Name() == name) {
      return m_children[i];
    }
  }
  LOG_ERROR_AND_THROW("Unable to find child ParameterTree with name '%s'", name.c_str());
  return nullptr; // never happens
}

// Returns the first direct child called `name`, or a newly created
// default-constructed tree if there is none.
std::shared_ptr<ParameterTree> ParameterTree::GetChildOrEmpty(const std::string& name) const {
  for (std::size_t i = 0; i < m_children.size(); ++i) {
    if (m_children[i]->Name() == name) {
      return m_children[i];
    }
  }
  return std::make_shared<ParameterTree>();
}
// Interprets `current` as the start of an array of T, returns a pointer to
// that start, and advances `current` past `num` elements of T.
template <typename T>
const T* get(const void*& current, size_t num = 1) {
  const T* const head = static_cast<const T*>(current);
  current = head + num;
  return head;
}
void ParameterTree::ReadBinary(const void*& current) {
auto nameLength = *get<int32_t>(current);
auto nameBytes = get<char>(current, nameLength);
m_name = std::string(nameBytes, nameBytes + nameLength);
auto textLength = *get<int32_t>(current);
auto textBytes = get<char>(current, textLength);
m_text = std::string(textBytes, textBytes + textLength);
int32_t num_children = *get<int32_t>(current);
m_children.resize(num_children);
for (int32_t i = 0; i < num_children; i++) {
m_children[i].reset(new ParameterTree());
m_children[i]->ReadBinary(current);
}
}
std::vector< std::shared_ptr<ParameterTree> > ParameterTree::GetChildren(const std::string& name) const {
std::vector< std::shared_ptr<ParameterTree> > children;
for (std::shared_ptr<ParameterTree> child : m_children) {
if (child->Name() == name) {
children.push_back(child);
}
}
return children;
}
void ParameterTree::AddParam(const std::string& name, const std::string& text) {
std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name);
child->SetText(text);
m_children.push_back(child);
}
void ParameterTree::SetParam(const std::string& name, const std::string& text) {
for (const auto& child : m_children) {
if (child->Name() == name) {
child->SetText(text);
return;
}
}
std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name);
child->SetText(text);
m_children.push_back(child);
}
void ParameterTree::AddChild(std::shared_ptr<ParameterTree> child) {
m_children.push_back(child);
}
// True if GetParamInternal finds a direct child called `name`.
bool ParameterTree::HasParam(const std::string& name) const {
  return GetParamInternal(name) != nullptr;
}

// True if at least one direct child is called `name`.
bool ParameterTree::HasChild(const std::string& name) const {
  bool found = false;
  for (std::size_t i = 0; !found && i < m_children.size(); ++i) {
    found = (m_children[i]->Name() == name);
  }
  return found;
}
// Renders this tree as tagged text, starting at indentation depth 0.
std::string ParameterTree::ToString() const {
  std::ostringstream rendered;
  ToStringInternal(0, rendered);
  return rendered.str();
}

// Returns a pointer to the text of the first direct child called `name`, or
// nullptr if there is none. The pointer refers to data owned by that child.
const std::string * ParameterTree::GetParamInternal(const std::string& name) const {
  for (std::size_t i = 0; i < m_children.size(); ++i) {
    if (m_children[i]->Name() == name) {
      return &(m_children[i]->Text());
    }
  }
  return nullptr;
}
void ParameterTree::RegisterItemInternal(const std::string& name, ParameterType type, void * param) {
if (m_registered_param_names.find(name) != m_registered_param_names.end()) {
LOG_ERROR_AND_THROW("Unable to register duplicate parameter name: '%s'", name.c_str());
}
m_registered_params.push_back(RegisteredParam(name, type, param));
m_registered_param_names.insert(name);
}
// Writes this node to `ss`, preceded by 2*depth padding strings.
// A node with children is written as an opening tag line, the children one
// level deeper, and a closing tag line; its own text is not written.
// A node without children is written as <name>text</name> on a single line.
void ParameterTree::ToStringInternal(int32_t depth, std::ostream& ss) const {
  const auto write_indent = [&ss, depth]() {
    for (int32_t remaining = 2 * depth; remaining > 0; --remaining) {
      ss << " ";
    }
  };

  write_indent();
  ss << "<" << m_name << ">";

  if (m_children.empty()) {
    ss << m_text << "</" << m_name << ">\n";
    return;
  }

  ss << "\n";
  for (std::size_t i = 0; i < m_children.size(); ++i) {
    m_children[i]->ToStringInternal(depth + 1, ss);
  }
  write_indent();
  ss << "</" << m_name << ">\n";
}
// Returns a deep copy of name, text and children. Registered parameters are
// not copied.
std::shared_ptr<ParameterTree> ParameterTree::Clone() const {
  auto copy = std::make_shared<ParameterTree>(m_name);
  copy->m_text = m_text;
  for (std::size_t i = 0; i < m_children.size(); ++i) {
    copy->m_children.push_back(m_children[i]->Clone());
  }
  return copy;
}
// Merges `other` into this node.
// Name and text are taken from `other`. For every child of `other`:
//   - if no child of that name exists here, a clone of it is appended;
//   - if one exists and both texts are non-empty, only the text is overwritten;
//   - otherwise the existing child is merged recursively.
void ParameterTree::Merge(const ParameterTree& other) {
  m_name = other.m_name;
  m_text = other.m_text;

  for (auto& theirs : other.m_children) {
    if (!HasChild(theirs->Name())) {
      m_children.push_back(theirs->Clone());
      continue;
    }

    auto mine = GetChildReq(theirs->Name());
    const bool both_have_text = !theirs->Text().empty() && !mine->Text().empty();
    if (both_have_text) {
      mine->SetText(theirs->Text());
    }
    else {
      mine->Merge(*theirs);
    }
  }
}
// Replaces every $$name$$ occurrence in this node's text by vars[name], then
// does the same for all children.
// A name that is not in `vars` is reported through LOG_ERROR_AND_THROW when
// `error_on_unknown_vars` is set; otherwise the occurrence is copied
// unchanged. Scanning stops at a "$$" that has no closing "$$"; the rest of
// the text is then copied verbatim.
void ParameterTree::ReplaceVariablesInternal(
    const std::unordered_map<std::string, std::string>& vars,
    bool error_on_unknown_vars)
{
  std::ostringstream expanded;
  std::size_t cursor = 0;  // first character not yet copied to `expanded`

  for (;;) {
    const std::size_t open = m_text.find("$$", cursor);
    if (open == std::string::npos) {
      break;
    }
    const std::size_t close = m_text.find("$$", open + 2);
    if (close == std::string::npos) {
      break;
    }

    // literal text in front of the variable
    if (open != cursor) {
      expanded << m_text.substr(cursor, open - cursor);
    }

    const std::string var_name = m_text.substr(open + 2, close - (open + 2));
    const auto found = vars.find(var_name);
    if (found != vars.end()) {
      const std::string replacement = found->second;
      expanded << replacement;
    }
    else if (error_on_unknown_vars) {
      LOG_ERROR_AND_THROW("The variable $$%s$$ was not found", var_name.c_str());
    }
    else {
      expanded << "$$" << var_name << "$$";
    }

    cursor = close + 2;
  }

  expanded << m_text.substr(cursor);
  m_text = expanded.str();

  for (auto& child : m_children) {
    child->ReplaceVariablesInternal(vars, error_on_unknown_vars);
  }
}
} // namespace quicksand

View File

@ -0,0 +1,185 @@
#pragma once
#include <string>
#include <vector>
#include <unordered_set>
#include <unordered_map>
#include <memory>
#include "microsoft/shortlist/utils/StringUtils.h"
namespace quicksand {
// A tree of named nodes. Every node has a name, a text value and an ordered
// list of children; a "parameter" is the text of a direct child.
class ParameterTree {
private:
  // Type tag stored with every registered parameter.
  enum ParameterType {
    PARAM_TYPE_INT32,
    PARAM_TYPE_INT64,
    PARAM_TYPE_UINT64,
    PARAM_TYPE_FLOAT,
    PARAM_TYPE_DOUBLE,
    PARAM_TYPE_BOOL,
    PARAM_TYPE_STRING
  };

  // Parameter name, its type tag and the untyped address of the variable that
  // was registered for it.
  class RegisteredParam {
  private:
    std::string m_name;
    // Initialized so that a default-constructed entry never exposes
    // indeterminate values through Type() or Data().
    ParameterType m_type{PARAM_TYPE_INT32};
    void * m_data{nullptr};

  public:
    RegisteredParam() {}

    RegisteredParam(const std::string& name,
                    ParameterType type,
                    void * data)
    {
      m_name = name;
      m_type = type;
      m_data = data;
    }

    const std::string& Name() const {return m_name;}
    const ParameterType& Type() const {return m_type;}
    void * Data() const {return m_data;}
  };

  static std::shared_ptr<ParameterTree> m_empty_tree;

  std::string m_name;
  std::string m_text;
  std::vector< std::shared_ptr<ParameterTree> > m_children;
  std::unordered_set<std::string> m_registered_param_names;
  std::vector<RegisteredParam> m_registered_params;

public:
  ParameterTree();
  ParameterTree(const std::string& name);
  ~ParameterTree();

  inline const std::string& Text() const { return m_text; }
  inline void SetText(const std::string& text) { m_text = text; }
  inline const std::string& Name() const { return m_name; }
  inline void SetName(const std::string& name) { m_name = name; }

  void Clear();

  // Substitutes $$name$$ occurrences in the text of this node and its descendants.
  void ReplaceVariables(
      const std::unordered_map<std::string, std::string>& vars,
      bool error_on_unknown_vars = true);

  // Register a variable under a parameter name; SetRegisteredParams() later
  // assigns the parameter values to the registered variables.
  void RegisterInt32(const std::string& name, int32_t * param);
  void RegisterInt64(const std::string& name, int64_t * param);
  void RegisterFloat(const std::string& name, float * param);
  void RegisterDouble(const std::string& name, double * param);
  void RegisterBool(const std::string& name, bool * param);
  void RegisterString(const std::string& name, std::string * param);

  // Builds a tree from binary data at `current`, advancing `current`.
  static std::shared_ptr<ParameterTree> FromBinaryReader(const void*& current);

  void SetRegisteredParams();

  // Required parameters: a missing parameter is an error.
  int32_t GetInt32Req(const std::string& name) const;
  int64_t GetInt64Req(const std::string& name) const;
  uint64_t GetUInt64Req(const std::string& name) const;
  double GetDoubleReq(const std::string& name) const;
  float GetFloatReq(const std::string& name) const;
  std::string GetStringReq(const std::string& name) const;
  bool GetBoolReq(const std::string& name) const;

  // Optional parameters: a missing parameter yields `defaultValue`.
  int32_t GetInt32Or(const std::string& name, int32_t defaultValue) const;
  int64_t GetInt64Or(const std::string& name, int64_t defaultValue) const;
  uint64_t GetUInt64Or(const std::string& name, uint64_t defaultValue) const;
  std::string GetStringOr(const std::string& name, const std::string& defaultValue) const;
  double GetDoubleOr(const std::string& name, double defaultValue) const;
  float GetFloatOr(const std::string& name, float defaultValue) const;
  bool GetBoolOr(const std::string& name, bool defaultValue) const;

  // List-valued parameters: file lists are split at ";", string lists at `sep`.
  std::vector<std::string> GetFileListReq(const std::string& name) const;
  std::vector<std::string> GetFileListOptional(const std::string& name) const;
  std::vector<std::string> GetStringListReq(const std::string& name, const std::string& sep = " ") const;
  std::vector<std::string> GetStringListOptional(const std::string& name, const std::string& sep = " ") const;

  std::shared_ptr<ParameterTree> GetChildReq(const std::string& name) const;
  std::shared_ptr<ParameterTree> GetChildOrEmpty(const std::string& name) const;
  std::vector< std::shared_ptr<ParameterTree> > GetChildren(const std::string& name) const;
  inline const std::vector< std::shared_ptr<ParameterTree> >& GetChildren() const { return m_children; }

  // Fills this node from binary data at `current`, advancing `current`.
  void ReadBinary(const void*& current);

  // AddParam always appends a child; SetParam overwrites the first child of
  // that name and only appends if there is none.
  void AddParam(const std::string& name, const std::string& text);
  template <typename T>
  void AddParam(const std::string& name, const T& obj);

  void SetParam(const std::string& name, const std::string& text);
  template <typename T>
  void SetParam(const std::string& name, const T& obj);

  void AddChild(std::shared_ptr<ParameterTree> child);

  std::string ToString() const;

  bool HasChild(const std::string& name) const;
  bool HasParam(const std::string& name) const;

  std::shared_ptr<ParameterTree> Clone() const;

  void Merge(const ParameterTree& other);

private:
  void ReplaceVariablesInternal(
      const std::unordered_map<std::string, std::string>& vars,
      bool error_on_unknown_vars);

  void RegisterItemInternal(const std::string& name, ParameterType type, void * param);

  const std::string * GetParamInternal(const std::string& name) const;

  void ToStringInternal(int32_t depth, std::ostream& ss) const;
};
// Generic AddParam: converts `obj` to text with StringUtils::ToString and
// forwards to the string overload of AddParam.
template <typename T>
void ParameterTree::AddParam(const std::string& name, const T& obj) {
AddParam(name, StringUtils::ToString(obj));
}
// Generic SetParam: converts `obj` to text with StringUtils::ToString and
// forwards to the string overload of SetParam.
template <typename T>
void ParameterTree::SetParam(const std::string& name, const T& obj) {
SetParam(name, StringUtils::ToString(obj));
}
} // namespace quicksand

View File

@ -0,0 +1,16 @@
#pragma once
#include <inttypes.h>
// printf-style conversion specifiers (without the leading '%') for 32-bit and
// 64-bit signed (PI*) and unsigned (PU*) integers.
// With QUICKSAND_WINDOWS_BUILD defined they are spelled out literally;
// otherwise they forward to the PRI* macros of <inttypes.h>.
#ifdef QUICKSAND_WINDOWS_BUILD
#define PI32 "d"
#define PI64 "lld"
#define PU32 "u"
#define PU64 "llu"
#else
#define PI32 PRId32
#define PI64 PRId64
#define PU32 PRIu32
#define PU64 PRIu64
#endif

View File

@ -0,0 +1,338 @@
#include "microsoft/shortlist/utils/StringUtils.h"
#include <stdio.h>
#include <algorithm>
#include <string>
namespace quicksand {
#include "microsoft/shortlist/logging/LoggerMacros.h"
// Formats `args` according to the printf-style `format` and returns the
// result as a std::string.
//
// A null `format` and formatting failures are reported through
// LOG_ERROR_AND_THROW. `args` is consumed by at most one formatting call; all
// other passes work on copies made with va_copy.
std::string StringUtils::VarArgsToString(const char * format, va_list args) {
  if (format == nullptr) {
    LOG_ERROR_AND_THROW("'format' cannot be null in StringUtils::VarArgsToString");
  }

  std::string output;
  // Most of the time the stack buffer (5000 chars) will be sufficient.
  // In cases where this is insufficient, dynamically allocate an appropriately sized buffer
  char buffer[5000];
#ifdef QUICKSAND_WINDOWS_BUILD
  va_list copy;
  va_copy(copy, args);
  int ret = vsnprintf_s(buffer, sizeof(buffer), _TRUNCATE, format, copy);
  va_end(copy);
  if (ret >= 0) {
    output = std::string(buffer, buffer + ret);
  }
  else {
    // A negative result means the output did not fit: ask for the exact size.
    va_list copy2;
    va_copy(copy2, args);
    int needed_size = _vscprintf(format, copy2);
    va_end(copy2);
    if (needed_size < 0) {
      LOG_ERROR_AND_THROW("A call to vsnprintf_s() failed. This should never happen");
    }
    else {
      // std::vector releases the buffer on every exit path.
      std::vector<char> dynamic_buffer(static_cast<std::size_t>(needed_size) + 1);
      int ret2 = vsnprintf_s(dynamic_buffer.data(), dynamic_buffer.size(), _TRUNCATE, format, args);
      if (ret2 >= 0) {
        output = std::string(dynamic_buffer.data(), dynamic_buffer.data() + ret2);
      }
      else {
        output = "";
        LOG_ERROR_AND_THROW("A call to vsnprintf_s() failed. This should never happen, "
          "since we made a call to _vscprintf() to check the dynamic buffer size. The call to _vscprintf() "
          "returned %d bytes, but apparently that was not enough. This would imply a bug in MSVC's vsnprintf_s implementation.", needed_size);
      }
    }
  }
#else
  va_list copy;
  va_copy(copy, args);
  const int needed_size = vsnprintf(buffer, sizeof(buffer), format, copy);
  va_end(copy);
  if (needed_size < 0) {
    // vsnprintf reports failure with a negative value; it must not be used
    // as a length.
    LOG_ERROR_AND_THROW("A call to vsnprintf() failed. Return value: %d.",
      needed_size);
  }
  else if (needed_size < (int)sizeof(buffer)) {
    output = std::string(buffer, buffer + needed_size);
  }
  else {
    // std::vector releases the buffer on every exit path.
    std::vector<char> dynamic_buffer(static_cast<std::size_t>(needed_size) + 1);
    const int ret = vsnprintf(dynamic_buffer.data(), dynamic_buffer.size(), format, args);
    if (ret >= 0 && ret < needed_size + 1) {
      // Use the reported length, as the stack-buffer path does.
      output = std::string(dynamic_buffer.data(), dynamic_buffer.data() + ret);
    }
    else {
      output = "";
      LOG_ERROR_AND_THROW("A call to vsnprintf() failed. Return value: %d.",
        ret);
    }
  }
#endif
  return output;
}
// Splits `input` at line breaks ("\n", "\r" or "\r\n") and returns the lines
// without their terminators. Empty lines in the middle are kept; no empty
// line is produced after a trailing line break. Empty input gives an empty list.
std::vector<std::string> StringUtils::SplitIntoLines(const std::string& input) {
  std::vector<std::string> lines;
  const std::size_t length = input.size();
  std::size_t line_begin = 0;

  std::size_t pos = 0;
  while (pos < length) {
    const char ch = input[pos];
    if (ch == '\r' || ch == '\n') {
      lines.push_back(input.substr(line_begin, pos - line_begin));
      // "\r\n" counts as a single line break: skip the '\n' as well
      if (ch == '\r' && pos + 1 < length && input[pos + 1] == '\n') {
        ++pos;
      }
      line_begin = pos + 1;
    }
    ++pos;
  }

  // do NOT put an empty length trailing line (but empty length intermediate lines are fine)
  if (line_begin != length) {
    lines.push_back(input.substr(line_begin));
  }
  return lines;
}
// Returns true when `str` begins with `prefix`; an empty prefix always matches.
bool StringUtils::StartsWith(const std::string& str, const std::string& prefix) {
    const std::size_t prefix_len = prefix.length();
    return str.length() >= prefix_len && str.compare(0, prefix_len, prefix) == 0;
}
// Returns true when `str` ends with `suffix`; an empty suffix always matches.
bool StringUtils::EndsWith(const std::string& str, const std::string& suffix) {
    const std::size_t suffix_len = suffix.length();
    if (str.length() < suffix_len) {
        return false;
    }
    // Compare only the last suffix_len characters of str
    return str.compare(str.length() - suffix_len, suffix_len, suffix) == 0;
}
// Parses a list of file names separated by line breaks and/or ';'.
// Every entry is passed through CleanupWhitespace() before being stored;
// entries that end up empty (e.g. from ";;") are still included.
std::vector<std::string> StringUtils::SplitFileList(const std::string& input) {
    std::vector<std::string> files;
    const std::vector<std::string> lines = SplitIntoLines(input);
    for (std::size_t li = 0; li < lines.size(); ++li) {
        const std::vector<std::string> entries = Split(lines[li], ";");
        for (std::size_t ei = 0; ei < entries.size(); ++ei) {
            files.push_back(CleanupWhitespace(entries[ei]));
        }
    }
    return files;
}
// Splits `input` at every occurrence of the character `splitter`.
// Empty fields (leading, trailing or between adjacent splitters) are kept.
// An empty input yields an empty vector rather than one empty field.
std::vector<std::string> StringUtils::Split(const std::string& input, char splitter) {
    std::vector<std::string> fields;
    if (input.empty()) {
        return fields;
    }
    std::size_t field_begin = 0;
    std::size_t hit = input.find(splitter);
    while (hit != std::string::npos) {
        fields.push_back(input.substr(field_begin, hit - field_begin));
        field_begin = hit + 1;
        hit = input.find(splitter, field_begin);
    }
    // Remainder after the last splitter (possibly empty)
    fields.push_back(input.substr(field_begin));
    return fields;
}
// Splits `input` at every occurrence of the string `splitter`.
// Empty fields (leading, trailing or between adjacent splitters) are kept.
// An empty input yields an empty vector rather than one empty field.
// An empty `splitter` cannot delimit anything, so the whole input is returned
// as a single field.
std::vector<std::string> StringUtils::Split(const std::string& input, const std::string& splitter) {
    std::vector<std::string> output;
    if (input.empty()) {
        return output;
    }
    // find("") matches at every position without advancing, which would loop
    // forever below; handle it explicitly.
    if (splitter.empty()) {
        output.push_back(input);
        return output;
    }
    std::size_t pos = 0;
    while (true) {
        const std::size_t next_pos = input.find(splitter, pos);
        if (next_pos == std::string::npos) {
            // Remainder after the last splitter (possibly empty)
            output.push_back(input.substr(pos));
            break;
        }
        output.push_back(input.substr(pos, next_pos - pos));
        pos = next_pos + splitter.size();
    }
    return output;
}
// Joins `length` unsigned bytes with `joiner` between them. Each byte is
// widened to int32_t first so it is written as a number, not as a character.
std::string StringUtils::Join(const std::string& joiner, const uint8_t * items, int32_t length) {
    std::ostringstream out;
    if (length > 0) {
        out << static_cast<int32_t>(items[0]);
        for (int32_t idx = 1; idx < length; ++idx) {
            out << joiner << static_cast<int32_t>(items[idx]);
        }
    }
    return out.str();
}
// Joins `length` signed bytes with `joiner` between them. Each byte is
// widened to int32_t first so it is written as a number, not as a character.
std::string StringUtils::Join(const std::string& joiner, const int8_t * items, int32_t length) {
    std::ostringstream out;
    if (length > 0) {
        out << static_cast<int32_t>(items[0]);
        for (int32_t idx = 1; idx < length; ++idx) {
            out << joiner << static_cast<int32_t>(items[idx]);
        }
    }
    return out.str();
}
// printf-style convenience wrapper: packs the variadic arguments into a
// va_list and delegates the formatting to VarArgsToString().
std::string StringUtils::PrintString(const char * format, ...) {
    va_list arg_list;
    va_start(arg_list, format);
    std::string formatted = StringUtils::VarArgsToString(format, arg_list);
    // The va_list must be released before returning
    va_end(arg_list);
    return formatted;
}
// Splits `input` into maximal runs of non-whitespace characters, where
// whitespace is defined by IsWhitespace(). Leading, trailing and repeated
// whitespace never produces empty tokens; an empty or all-whitespace input
// yields an empty vector.
std::vector<std::string> StringUtils::WhitespaceTokenize(const std::string& input) {
    std::vector<std::string> tokens;
    const std::size_t length = input.size();
    std::size_t pos = 0;
    while (true) {
        // Skip the whitespace run preceding the next token
        while (pos < length && IsWhitespace(input[pos])) {
            ++pos;
        }
        if (pos == length) {
            break;
        }
        // Consume the token itself
        const std::size_t token_begin = pos;
        while (pos < length && !IsWhitespace(input[pos])) {
            ++pos;
        }
        tokens.push_back(input.substr(token_begin, pos - token_begin));
    }
    return tokens;
}
// Normalizes whitespace (as defined by IsWhitespace()): leading and trailing
// whitespace is dropped and every interior run of whitespace is replaced by
// one ' '. An empty or all-whitespace input yields an empty string.
std::string StringUtils::CleanupWhitespace(const std::string& input) {
    std::string cleaned;
    bool gap_pending = false;
    for (const char ch : input) {
        if (IsWhitespace(ch)) {
            // Whitespace only becomes a separator once some text precedes it
            gap_pending = !cleaned.empty();
            continue;
        }
        // Emit the separator lazily so trailing whitespace is never written
        if (gap_pending) {
            cleaned += ' ';
            gap_pending = false;
        }
        cleaned += ch;
    }
    return cleaned;
}
// Returns a copy of `str` in which the five characters & " ' < > are replaced
// by their XML entities; all other characters are copied unchanged.
std::string StringUtils::XmlEscape(const std::string& str) {
    std::string escaped;
    escaped.reserve(str.size());
    for (const char ch : str) {
        switch (ch) {
            case '&':  escaped += "&amp;";  break;
            case '"':  escaped += "&quot;"; break;
            case '\'': escaped += "&apos;"; break;
            case '<':  escaped += "&lt;";   break;
            case '>':  escaped += "&gt;";   break;
            default:   escaped += ch;       break;
        }
    }
    return escaped;
}
// Identity overload: a string converts to itself (returned as a copy).
std::string StringUtils::ToString(const std::string& str) {
    return std::string(str);
}
// Converts a bool to the text "true" or "false".
std::string StringUtils::ToString(bool obj) {
    if (obj) {
        return "true";
    }
    return "false";
}
// Returns a copy of `str` with every character passed through toupper().
// Intended for ASCII text such as file names, not for language data.
std::string StringUtils::ToUpper(const std::string& str) {
    std::string output;
    output.reserve(str.size());
    for (const char c : str) {
        // toupper() requires a value representable as unsigned char; passing a
        // negative char (bytes >= 0x80 where char is signed) is undefined behaviour.
        output.push_back(static_cast<char>(toupper(static_cast<unsigned char>(c))));
    }
    return output;
}
// Returns a copy of `str` with every character passed through tolower().
// Intended for ASCII text such as file names, not for language data.
std::string StringUtils::ToLower(const std::string& str) {
    std::string output;
    output.reserve(str.size());
    for (const char c : str) {
        // tolower() requires a value representable as unsigned char; passing a
        // negative char (bytes >= 0x80 where char is signed) is undefined behaviour.
        output.push_back(static_cast<char>(tolower(static_cast<unsigned char>(c))));
    }
    return output;
}
} // namespace quicksand

View File

@ -0,0 +1,98 @@
#pragma once
#include <string>
#include <sstream>
#include <stdarg.h>
#include <vector>
#include <stdint.h>
#include "microsoft/shortlist/utils/PrintTypes.h"
namespace quicksand {

// Static-only collection of string helpers: joining, splitting, whitespace
// handling, printf-style formatting and simple conversions.
class StringUtils {
public:
    // Streams every element of a container (anything offering begin()/end())
    // into one string, with `joiner` between consecutive elements.
    template <typename T>
    static std::string Join(const std::string& joiner, const T& items);

    // Same as above for a raw array of `length` elements.
    template <typename T>
    static std::string Join(const std::string& joiner, const T * items, int32_t length);

    // Byte overloads: the elements are written as numbers, not as characters.
    static std::string Join(const std::string& joiner, const uint8_t * items, int32_t length);

    static std::string Join(const std::string& joiner, const int8_t * items, int32_t length);

    // Split at every occurrence of a character / string. Empty fields are
    // kept; an empty input yields an empty vector.
    static std::vector<std::string> Split(const std::string& input, char splitter);

    static std::vector<std::string> Split(const std::string& input, const std::string& splitter);

    // Splits a file list at line breaks and ';', cleaning up the whitespace of
    // each entry.
    static std::vector<std::string> SplitFileList(const std::string& input);

    // printf-style formatting into a std::string.
    static std::string PrintString(const char * format, ...);

    // va_list flavour of PrintString().
    static std::string VarArgsToString(const char * format, va_list args);

    // Returns the non-whitespace tokens of `input` (never any empty token).
    static std::vector<std::string> WhitespaceTokenize(const std::string& input);

    // Trims `input` and collapses each interior whitespace run to one ' '.
    static std::string CleanupWhitespace(const std::string& input);

    // Conversions to string: identity for strings, "true"/"false" for bool,
    // operator<< for everything else.
    static std::string ToString(const std::string& str);

    static std::string ToString(bool obj);

    template <typename T>
    static std::string ToString(const T& obj);

    // Replaces & " ' < > by the corresponding XML entities.
    static std::string XmlEscape(const std::string& str);

    // Splits at "\r", "\n" and "\r\n"; no empty line is produced after a
    // terminator that ends the input.
    static std::vector<std::string> SplitIntoLines(const std::string& input);

    static bool StartsWith(const std::string& str, const std::string& prefix);

    static bool EndsWith(const std::string& str, const std::string& suffix);

    // Whitespace here means exactly: space, tab, line feed, carriage return.
    inline static bool IsWhitespace(char c) {
        switch (c) {
            case ' ':
            case '\t':
            case '\n':
            case '\r':
                return true;
            default:
                return false;
        }
    }

    // This should only be used for ASCII, e.g., filenames, NOT for language data
    static std::string ToLower(const std::string& str);

    // This should only be used for ASCII, e.g., filenames, NOT for language data
    static std::string ToUpper(const std::string& str);
};

template <typename T>
std::string StringUtils::Join(const std::string& joiner, const T& items) {
    std::ostringstream out;
    auto cursor = items.begin();
    const auto stop = items.end();
    if (cursor != stop) {
        // First element has no separator in front of it
        out << *cursor;
        for (++cursor; cursor != stop; ++cursor) {
            out << joiner << *cursor;
        }
    }
    return out.str();
}

template <typename T>
std::string StringUtils::Join(const std::string& joiner, const T * items, int32_t length) {
    std::ostringstream out;
    if (length > 0) {
        // First element has no separator in front of it
        out << items[0];
        for (int32_t idx = 1; idx < length; ++idx) {
            out << joiner << items[idx];
        }
    }
    return out.str();
}

template <typename T>
std::string StringUtils::ToString(const T& obj) {
    std::ostringstream out;
    out << obj;
    return out.str();
}

} // namespace quicksand

View File

@ -60,8 +60,7 @@ public:
auto srcVocab = corpus_->getVocabs()[0];
if(options_->hasAndNotEmpty("shortlist"))
shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
auto devices = Config::getDevices(options_);
numDevices_ = devices.size();