mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 18185: Support for Microsoft legacy binary shortlist
Adds support for Microsoft-internal binary shortlist format.
This commit is contained in:
parent
a1aaa32c6a
commit
e08c52a8df
@ -40,6 +40,7 @@ set(MARIAN_SOURCES
|
||||
data/corpus_sqlite.cpp
|
||||
data/corpus_nbest.cpp
|
||||
data/text_input.cpp
|
||||
data/shortlist.cpp
|
||||
|
||||
3rd_party/cnpy/cnpy.cpp
|
||||
3rd_party/ExceptionWithCallStack.cpp
|
||||
@ -107,10 +108,15 @@ set(MARIAN_SOURCES
|
||||
training/validator.cpp
|
||||
training/communicator.cpp
|
||||
|
||||
# this is only compiled to catch build errors, but not linked
|
||||
# this is only compiled to catch build errors
|
||||
microsoft/quicksand.cpp
|
||||
microsoft/cosmos.cpp
|
||||
|
||||
# copied from quicksand to be able to read binary shortlist
|
||||
microsoft/shortlist/utils/Converter.cpp
|
||||
microsoft/shortlist/utils/StringUtils.cpp
|
||||
microsoft/shortlist/utils/ParameterTree.cpp
|
||||
|
||||
$<TARGET_OBJECTS:libyaml-cpp>
|
||||
$<TARGET_OBJECTS:SQLiteCpp>
|
||||
$<TARGET_OBJECTS:pathie-cpp>
|
||||
|
@ -546,7 +546,6 @@ void FactoredVocab::constructNormalizationInfoForVocab() {
|
||||
/*virtual*/ void FactoredVocab::transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const {
|
||||
for (; num-- > 0; ptr++) {
|
||||
auto word = Word::fromWordIndex(*ptr);
|
||||
auto wordString = word2string(word);
|
||||
auto lemmaIndex = getFactor(word, 0) + groupRanges_[0].first;
|
||||
*ptr = (WordIndex)lemmaIndex;
|
||||
}
|
||||
|
153
src/data/shortlist.cpp
Normal file
153
src/data/shortlist.cpp
Normal file
@ -0,0 +1,153 @@
|
||||
#include "data/shortlist.h"
|
||||
#include "microsoft/shortlist/utils/ParameterTree.h"
|
||||
|
||||
namespace marian {
|
||||
namespace data {
|
||||
|
||||
// Cursor-style reader over a raw byte buffer: interprets `current` as pointing at
// an array of T, returns a pointer to its first element and moves `current`
// forward past `num` elements of T. No bounds or alignment checking is done here.
template <typename T>
const T* get(const void*& current, size_t num = 1) {
  const T* const head = static_cast<const T*>(current);
  current = head + num;
  return head;
}
|
||||
|
||||
// Memory-maps the binary shortlist named by the first value of the "shortlist"
// option and records pointers into the mapping for each section of the file.
//
// Layout as consumed below: int32 magic number, a serialized ParameterTree
// (providing "use_16_bit"), then
//   int32 numDefaultIds,   int32[numDefaultIds]  default ids,
//   int32 numSourceIds,    int32[numSourceIds]   per-source lengths,
//                          int32[numSourceIds]   per-source offsets,
//   int32 numShortlistIds, idSize * numShortlistIds bytes of shortlist ids.
//
// Aborts if no path is given, if the magic number does not match, or if a
// section would extend past the end of the mapped file.
// The trgIdx and shared arguments are accepted for signature parity but unused.
QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
                                                         Ptr<const Vocab> srcVocab,
                                                         Ptr<const Vocab> trgVocab,
                                                         size_t srcIdx,
                                                         size_t /*trgIdx*/,
                                                         bool /*shared*/)
    : options_(options),
      srcVocab_(srcVocab),
      trgVocab_(trgVocab),
      srcIdx_(srcIdx) {
  std::vector<std::string> vals = options_->get<std::vector<std::string>>("shortlist");

  ABORT_IF(vals.empty(), "No path to shortlist given");
  std::string fname = vals[0];

  // The numeric arguments understood by the text shortlist are parsed only so
  // that we can warn the user that they have no effect for this format.
  auto firstNum   = vals.size() > 1 ? std::stoi(vals[1]) : 0;
  auto bestNum    = vals.size() > 2 ? std::stoi(vals[2]) : 0;
  float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0;

  if(firstNum != 0 || bestNum != 0 || threshold != 0) {
    LOG(warn, "You have provided additional parameters for the Quicksand shortlist, but they are ignored.");
  }

  mmap_ = mio::mmap_source(fname);  // memory-map the binary file once
  const char* const fileBegin = mmap_.data();
  const char* const fileEnd   = fileBegin + mmap_.size();
  const void* current = fileBegin;  // pointer iterator over binary file

  // Abort unless at least `bytes` bytes remain between the cursor and the end
  // of the mapping; guards every read below against truncated/corrupt files.
  auto requireBytes = [&](size_t bytes) {
    const char* pos = static_cast<const char*>(current);
    ABORT_IF(pos > fileEnd || static_cast<size_t>(fileEnd - pos) < bytes,
             "Quicksand shortlist file is truncated or corrupt");
  };

  // Read one int32 element count at the cursor and reject negative values.
  auto readCount = [&]() {
    requireBytes(sizeof(int32_t));
    int32_t count = *get<int32_t>(current);
    ABORT_IF(count < 0, "Quicksand shortlist file is truncated or corrupt");
    return count;
  };

  // compare magic number in binary file to make sure we are reading the right thing
  const int32_t MAGIC_NUMBER = 1234567890;
  requireBytes(sizeof(int32_t));
  int32_t header_magic_number = *get<int32_t>(current);
  ABORT_IF(header_magic_number != MAGIC_NUMBER, "Trying to mmap Quicksand shortlist but encountered wrong magic number");

  // The parameter tree advances `current` itself; it has no knowledge of the
  // end of the mapping, so the next requireBytes() call re-validates the cursor.
  auto config = ::quicksand::ParameterTree::FromBinaryReader(current);
  use16bit_ = config->GetBoolReq("use_16_bit");

  LOG(info, "[data] Mapping Quicksand shortlist from {}", fname);

  // width in bytes of one stored shortlist id
  idSize_ = sizeof(int32_t);
  if(use16bit_) {
    idSize_ = sizeof(uint16_t);
  }

  // mmap the binary shortlist pieces
  numDefaultIds_ = readCount();
  requireBytes(sizeof(int32_t) * static_cast<size_t>(numDefaultIds_));
  defaultIds_ = get<int32_t>(current, numDefaultIds_);

  numSourceIds_ = readCount();
  requireBytes(2 * sizeof(int32_t) * static_cast<size_t>(numSourceIds_));
  sourceLengths_ = get<int32_t>(current, numSourceIds_);
  sourceOffsets_ = get<int32_t>(current, numSourceIds_);

  numShortlistIds_ = readCount();
  // computed in size_t so that a large id count cannot overflow int32_t
  const size_t shortlistBytes = static_cast<size_t>(idSize_) * static_cast<size_t>(numShortlistIds_);
  requireBytes(shortlistBytes);
  sourceToShortlistIds_ = get<uint8_t>(current, shortlistBytes);

  // display parameters
  LOG(info,
      "[data] Quicksand shortlist has {} source ids, {} default ids and {} shortlist ids",
      numSourceIds_,
      numDefaultIds_,
      numShortlistIds_);
}
|
||||
|
||||
// Builds the shortlist for one batch.
//
// The result always starts from the default ids, then adds the per-source-word
// candidates in order of rank (first candidate of every word, then second
// candidate of every word, ...) until either all candidates are used or the
// set has reached the size of the target vocabulary. The collected ids are
// returned sorted in ascending order.
Ptr<Shortlist> QuicksandShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
  auto srcBatch = (*batch)[srcIdx_];
  const size_t maxShortlistSize = trgVocab_->size();

  std::unordered_set<int32_t> indexSet;
  for(int32_t i = 0; i < numDefaultIds_ && static_cast<size_t>(i) < maxShortlistSize; ++i) {
    indexSet.insert(defaultIds_[i]);
  }

  // One (pointer to first candidate, number of candidates) entry per source
  // word that has a mapping. The vector grows with the batch instead of being
  // pre-sized by the target vocabulary, so it can neither overflow for batches
  // with many words nor carry unused slots into the rank loop below.
  std::vector<std::pair<const uint8_t*, int32_t>> curShortlists;
  curShortlists.reserve(srcBatch->data().size());

  // Because we might fill up our shortlist before reaching max_shortlist_size, we fill the shortlist in order of rank.
  // E.g., first rank of word 0, first rank of word 1, ... second rank of word 0, ...
  int32_t maxLength = 0;
  for(Word word : srcBatch->data()) {
    int32_t sourceId = (int32_t)word.toWordIndex();
    srcVocab_->transcodeToShortlistInPlace((WordIndex*)&sourceId, 1);

    // Valid source ids are [0, numSourceIds_); the lower bound matters because
    // a large unsigned word index becomes negative after the int32_t cast.
    if(sourceId >= 0 && sourceId < numSourceIds_) {
      // start position for mapping; offset computed in size_t to avoid int32 overflow
      const uint8_t* curShortlistIds
          = sourceToShortlistIds_ + static_cast<size_t>(idSize_) * static_cast<size_t>(sourceOffsets_[sourceId]);
      int32_t length = sourceLengths_[sourceId];  // how many mappings are there
      curShortlists.emplace_back(curShortlistIds, length);

      if(length > maxLength)
        maxLength = length;
    }
  }

  // collect the actual shortlist mappings, rank by rank
  for(int32_t rank = 0; rank < maxLength && indexSet.size() < maxShortlistSize; ++rank) {
    for(const auto& entry : curShortlists) {
      if(indexSet.size() >= maxShortlistSize)
        break;
      if(rank >= entry.second)
        continue;  // this word has no candidate at this rank

      int32_t id = 0;
      if(use16bit_) {
        const uint16_t* ids16 = reinterpret_cast<const uint16_t*>(entry.first);
        id = (int32_t)ids16[rank];
      } else {
        const int32_t* ids32 = reinterpret_cast<const int32_t*>(entry.first);
        id = ids32[rank];
      }
      indexSet.insert(id);
    }
  }

  // turn into vector and sort (selected indices)
  std::vector<WordIndex> indices;
  indices.reserve(indexSet.size());
  for(auto i : indexSet)
    indices.push_back((WordIndex)i);

  std::sort(indices.begin(), indices.end());
  return New<Shortlist>(indices);
}
|
||||
|
||||
// Factory: inspects the first value of the "shortlist" option and returns the
// binary (Quicksand) generator when that path has the extension ".bin",
// otherwise the lexical generator. Aborts when the option has no values.
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
                                                 Ptr<const Vocab> srcVocab,
                                                 Ptr<const Vocab> trgVocab,
                                                 size_t srcIdx,
                                                 size_t trgIdx,
                                                 bool shared) {
  const std::vector<std::string> shortlistArgs = options->get<std::vector<std::string>>("shortlist");
  ABORT_IF(shortlistArgs.empty(), "No path to shortlist given");

  const std::string extension = filesystem::Path(shortlistArgs[0]).extension().string();
  if(extension != ".bin")
    return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);

  return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
}
|
||||
|
||||
} // namespace data
|
||||
} // namespace marian
|
@ -5,6 +5,7 @@
|
||||
#include "common/file_stream.h"
|
||||
#include "data/corpus_base.h"
|
||||
#include "data/types.h"
|
||||
#include "mio/mio.hpp"
|
||||
|
||||
#include <random>
|
||||
#include <unordered_map>
|
||||
@ -292,5 +293,51 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
Legacy binary shortlist for Microsoft-internal use.
|
||||
*/
|
||||
class QuicksandShortlistGenerator : public ShortlistGenerator {
|
||||
private:
|
||||
Ptr<Options> options_;
|
||||
Ptr<const Vocab> srcVocab_;
|
||||
Ptr<const Vocab> trgVocab_;
|
||||
|
||||
size_t srcIdx_;
|
||||
|
||||
mio::mmap_source mmap_;
|
||||
|
||||
// all the quicksand bits go here
|
||||
bool use16bit_{false};
|
||||
int32_t numDefaultIds_;
|
||||
int32_t idSize_;
|
||||
const int32_t* defaultIds_{nullptr};
|
||||
int32_t numSourceIds_{0};
|
||||
const int32_t* sourceLengths_{nullptr};
|
||||
const int32_t* sourceOffsets_{nullptr};
|
||||
int32_t numShortlistIds_{0};
|
||||
const uint8_t* sourceToShortlistIds_{nullptr};
|
||||
|
||||
public:
|
||||
QuicksandShortlistGenerator(Ptr<Options> options,
|
||||
Ptr<const Vocab> srcVocab,
|
||||
Ptr<const Vocab> trgVocab,
|
||||
size_t srcIdx = 0,
|
||||
size_t trgIdx = 1,
|
||||
bool shared = false);
|
||||
|
||||
virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
|
||||
};
|
||||
|
||||
/*
|
||||
Shortlist factory to create correct type of shortlist. Currently assumes everything is a text shortlist
|
||||
unless the extension is *.bin for which the Microsoft legacy binary shortlist is used.
|
||||
*/
|
||||
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
|
||||
Ptr<const Vocab> srcVocab,
|
||||
Ptr<const Vocab> trgVocab,
|
||||
size_t srcIdx = 0,
|
||||
size_t trgIdx = 1,
|
||||
bool shared = false);
|
||||
|
||||
} // namespace data
|
||||
} // namespace marian
|
||||
|
25
src/microsoft/shortlist/logging/LoggerMacros.h
Normal file
25
src/microsoft/shortlist/logging/LoggerMacros.h
Normal file
@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
// Do NOT include this file directly except in special circumstances.
|
||||
// (E.g., you want to define macros which call these but don't want to include Logger.h everywhere).
|
||||
// Normally you should include logging/Logger.h
|
||||
|
||||
// Stand-ins for the Quicksand logging macros. Every one of them discards its
// arguments and terminates the process via abort(); nothing is printed.
// NOTE(review): abort() is not declared by this header, so the including file
// presumably has to make it visible - confirm.

// Formatted log line.
#define LOG_WRITE(format, ...) do { abort(); } while (0)

// Pre-formatted log line.
#define LOG_WRITE_STRING(str) do { abort(); } while (0)

// Formatted error message.
#define LOG_ERROR(format, ...) do { abort(); } while (0)

// Formatted error message followed by a throw in the original code base.
#define LOG_ERROR_AND_THROW(format, ...) do { abort(); } while (0)

// Formatted decoding-logic error.
#define DECODING_LOGIC_ERROR(format, ...) do { abort(); } while (0)
|
59
src/microsoft/shortlist/utils/Converter.cpp
Normal file
59
src/microsoft/shortlist/utils/Converter.cpp
Normal file
@ -0,0 +1,59 @@
|
||||
#include "microsoft/shortlist/utils/Converter.h"
|
||||
|
||||
namespace quicksand {

// Included inside the namespace as before; the header only defines macros.
#include "microsoft/shortlist/logging/LoggerMacros.h"

// Parses `str` as a signed 64-bit integer; reports a conversion error otherwise.
int64_t Converter::ToInt64(const std::string& str) {
    return ConvertSingleInternal<int64_t>(str, "int64_t");
}

// Parses `str` as an unsigned 64-bit integer; reports a conversion error otherwise.
uint64_t Converter::ToUInt64(const std::string& str) {
    // the type name passed here is the one shown in the error message
    return ConvertSingleInternal<uint64_t>(str, "uint64_t");
}

// Parses `str` as a signed 32-bit integer; reports a conversion error otherwise.
int32_t Converter::ToInt32(const std::string& str) {
    return ConvertSingleInternal<int32_t>(str, "int32_t");
}

// Parses `str` as a float; reports a conversion error otherwise.
float Converter::ToFloat(const std::string& str) {
    // In case the value is out of range of a 32-bit float, but in range of a 64-bit double,
    // it's better to convert as a double and then do the conversion.
    return (float)ConvertSingleInternal<double>(str, "float");
}

// Parses `str` as a double; reports a conversion error otherwise.
double Converter::ToDouble(const std::string& str) {
    return ConvertSingleInternal<double>(str, "double");
}

// Parses `str` as a bool using the spellings accepted by TryConvert;
// reports a conversion error for anything else.
bool Converter::ToBool(const std::string& str) {
    bool value = false;
    if (!TryConvert(str, /* out */ value)) {
        LOG_ERROR_AND_THROW("The string '%s' is not interpretable as the type 'bool'", str.c_str());
    }
    return value;
}

// Element-wise ToInt32 over `items`.
std::vector<int32_t> Converter::ToInt32Vector(const std::vector<std::string>& items) {
    return ConvertVectorInternal<int32_t, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "int32_t");
}

// Element-wise ToInt64 over `items`.
std::vector<int64_t> Converter::ToInt64Vector(const std::vector<std::string>& items) {
    return ConvertVectorInternal<int64_t, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "int64_t");
}

// Element-wise float parsing over `items`.
std::vector<float> Converter::ToFloatVector(const std::vector<std::string>& items) {
    return ConvertVectorInternal<float, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "float");
}

// Element-wise double parsing over `items`.
std::vector<double> Converter::ToDoubleVector(const std::vector<std::string>& items) {
    return ConvertVectorInternal<double, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "double");
}

// Common failure path of the conversion helpers.
void Converter::HandleConversionError(const std::string& str, const char * type_name) {
    // The stubbed logging macro drops its arguments, so mark both parameters
    // as intentionally unused to keep the compiler quiet.
    (void)str;
    (void)type_name;
    LOG_ERROR_AND_THROW("The string '%s' is not interpretable as the type '%s'", str.c_str(), type_name);
}

} // namespace quicksand
|
83
src/microsoft/shortlist/utils/Converter.h
Normal file
83
src/microsoft/shortlist/utils/Converter.h
Normal file
@ -0,0 +1,83 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
namespace quicksand {

// Collection of static helpers that turn strings into numbers and booleans.
// The To* functions report a conversion error on failure; the TryConvert
// overloads signal failure through their return value instead.
class Converter {
public:
    static int32_t ToInt32(const std::string& str);

    static int64_t ToInt64(const std::string& str);

    static uint64_t ToUInt64(const std::string& str);

    static float ToFloat(const std::string& str);

    static double ToDouble(const std::string& str);

    static bool ToBool(const std::string& str);

    static std::vector<int32_t> ToInt32Vector(const std::vector<std::string>& items);

    static std::vector<int64_t> ToInt64Vector(const std::vector<std::string>& items);

    static std::vector<float> ToFloatVector(const std::vector<std::string>& items);

    static std::vector<double> ToDoubleVector(const std::vector<std::string>& items);

    // Recognizes a fixed set of boolean spellings. On a match, stores the
    // value in `obj` and returns true; otherwise returns false and leaves
    // `obj` untouched.
    static bool TryConvert(const std::string& str, /* out*/ bool& obj) {
        const bool is_true = str == "True" || str == "true" || str == "TRUE"
                          || str == "Yes" || str == "yes" || str == "1";
        const bool is_false = str == "False" || str == "false" || str == "FALSE"
                           || str == "No" || str == "no" || str == "0";
        if (!is_true && !is_false) {
            return false;
        }
        obj = is_true;
        return true;
    }

    // Stream-extracts a T from `str` into `value` (reset to T() first) and
    // returns whether the extraction succeeded.
    template <typename T>
    static bool TryConvert(const std::string& str, /* out*/ T& value) {
        value = T();
        std::istringstream ss(str);
        return static_cast<bool>(ss >> value);
    }

private:
    template <typename T>
    static T ConvertSingleInternal(const std::string& str, const char * type_name);

    template <typename T, typename I>
    static std::vector<T> ConvertVectorInternal(I begin, I end, const char * type_name);

    static void HandleConversionError(const std::string& str, const char * type_name);
};

// Stream-extracts a single T from `str`; when extraction fails the error
// handler is invoked with `type_name` for the message.
template <typename T>
T Converter::ConvertSingleInternal(const std::string& str, const char * type_name) {
    T value = T();
    std::istringstream ss(str);
    const bool parsed = static_cast<bool>(ss >> value);
    if (!parsed) {
        HandleConversionError(str, type_name);
    }
    return value;
}

// Converts every string in [begin, end) with ConvertSingleInternal and
// returns the results in the same order.
template <typename T, typename I>
std::vector<T> Converter::ConvertVectorInternal(I begin, I end, const char * type_name) {
    std::vector<T> items;
    while (begin != end) {
        items.push_back(ConvertSingleInternal<T>(*begin, type_name));
        ++begin;
    }
    return items;
}

} // namespace quicksand
|
417
src/microsoft/shortlist/utils/ParameterTree.cpp
Normal file
417
src/microsoft/shortlist/utils/ParameterTree.cpp
Normal file
@ -0,0 +1,417 @@
|
||||
#include "microsoft/shortlist/utils/ParameterTree.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "microsoft/shortlist/utils/StringUtils.h"
|
||||
#include "microsoft/shortlist/utils/Converter.h"
|
||||
|
||||
namespace quicksand {

// Included inside the namespace as before; the header only defines macros.
#include "microsoft/shortlist/logging/LoggerMacros.h"

// Definition of the static tree member; it is named "params".
std::shared_ptr<ParameterTree> ParameterTree::m_empty_tree = std::make_shared<ParameterTree>("params");

// Creates a node named "root".
ParameterTree::ParameterTree()
    : m_name("root") {
}

// Creates a node with the given name.
ParameterTree::ParameterTree(const std::string& name)
    : m_name(name) {
}

ParameterTree::~ParameterTree() {
}

// NOTE(review): intentionally left as a no-op to match the existing behavior;
// despite its name this does not reset the node - confirm whether that is intended.
void ParameterTree::Clear() {
}

// Replaces every $$name$$ occurrence in this node's text and in all
// descendants with the value found in `vars`.
void ParameterTree::ReplaceVariables(
    const std::unordered_map<std::string, std::string>& vars,
    bool error_on_unknown_vars)
{
    ReplaceVariablesInternal(vars, error_on_unknown_vars);
}

// The Register* functions remember a (name, type, destination) triple; the
// destinations are filled in later by SetRegisteredParams().
void ParameterTree::RegisterInt32(const std::string& name, int32_t * param) {
    RegisterItemInternal(name, PARAM_TYPE_INT32, static_cast<void *>(param));
}

void ParameterTree::RegisterInt64(const std::string& name, int64_t * param) {
    RegisterItemInternal(name, PARAM_TYPE_INT64, static_cast<void *>(param));
}

void ParameterTree::RegisterFloat(const std::string& name, float * param) {
    RegisterItemInternal(name, PARAM_TYPE_FLOAT, static_cast<void *>(param));
}

void ParameterTree::RegisterDouble(const std::string& name, double * param) {
    RegisterItemInternal(name, PARAM_TYPE_DOUBLE, static_cast<void *>(param));
}

void ParameterTree::RegisterBool(const std::string& name, bool * param) {
    RegisterItemInternal(name, PARAM_TYPE_BOOL, static_cast<void *>(param));
}

void ParameterTree::RegisterString(const std::string& name, std::string * param) {
    RegisterItemInternal(name, PARAM_TYPE_STRING, static_cast<void *>(param));
}

// Deserializes a tree from the binary buffer at `current` and advances
// `current` past the bytes that were consumed.
std::shared_ptr<ParameterTree> ParameterTree::FromBinaryReader(const void*& current) {
    std::shared_ptr<ParameterTree> root = std::make_shared<ParameterTree>();
    root->ReadBinary(current);
    return root;
}

// Writes the value of every registered parameter into its registered
// destination. Each parameter is required to be present in the tree.
// Every type that can be registered is handled; an unrecognized type tag
// is reported as an error.
void ParameterTree::SetRegisteredParams() {
    for (const RegisteredParam& rp : m_registered_params) {
        switch (rp.Type()) {
        case PARAM_TYPE_INT32:
            *static_cast<int32_t *>(rp.Data()) = GetInt32Req(rp.Name());
            break;
        case PARAM_TYPE_INT64:
            *static_cast<int64_t *>(rp.Data()) = GetInt64Req(rp.Name());
            break;
        case PARAM_TYPE_UINT64:
            *static_cast<uint64_t *>(rp.Data()) = GetUInt64Req(rp.Name());
            break;
        case PARAM_TYPE_FLOAT:
            *static_cast<float *>(rp.Data()) = GetFloatReq(rp.Name());
            break;
        case PARAM_TYPE_DOUBLE:
            *static_cast<double *>(rp.Data()) = GetDoubleReq(rp.Name());
            break;
        case PARAM_TYPE_BOOL:
            *static_cast<bool *>(rp.Data()) = GetBoolReq(rp.Name());
            break;
        case PARAM_TYPE_STRING:
            *static_cast<std::string *>(rp.Data()) = GetStringReq(rp.Name());
            break;
        default:
            LOG_ERROR_AND_THROW("Unknown ParameterType: %d", (int)rp.Type());
        }
    }
}

// The Get*Or functions return the converted text of the first child named
// `name`, or `defaultValue` when there is no such child.
int32_t ParameterTree::GetInt32Or(const std::string& name, int32_t defaultValue) const {
    const std::string * value = GetParamInternal(name);
    return value != nullptr ? Converter::ToInt32(*value) : defaultValue;
}

int64_t ParameterTree::GetInt64Or(const std::string& name, int64_t defaultValue) const {
    const std::string * value = GetParamInternal(name);
    return value != nullptr ? Converter::ToInt64(*value) : defaultValue;
}

uint64_t ParameterTree::GetUInt64Or(const std::string& name, uint64_t defaultValue) const {
    const std::string * value = GetParamInternal(name);
    return value != nullptr ? Converter::ToUInt64(*value) : defaultValue;
}

double ParameterTree::GetDoubleOr(const std::string& name, double defaultValue) const {
    const std::string * value = GetParamInternal(name);
    return value != nullptr ? Converter::ToDouble(*value) : defaultValue;
}

float ParameterTree::GetFloatOr(const std::string& name, float defaultValue) const {
    const std::string * value = GetParamInternal(name);
    return value != nullptr ? Converter::ToFloat(*value) : defaultValue;
}

std::string ParameterTree::GetStringOr(const std::string& name, const std::string& defaultValue) const {
    const std::string * value = GetParamInternal(name);
    return value != nullptr ? *value : defaultValue;
}

bool ParameterTree::GetBoolOr(const std::string& name, bool defaultValue) const {
    const std::string * value = GetParamInternal(name);
    return value != nullptr ? Converter::ToBool(*value) : defaultValue;
}

// The Get*Req functions convert the text of the first child named `name`;
// a missing child is reported as an error by GetStringReq.
int32_t ParameterTree::GetInt32Req(const std::string& name) const {
    return Converter::ToInt32(GetStringReq(name));
}

uint64_t ParameterTree::GetUInt64Req(const std::string& name) const {
    return Converter::ToUInt64(GetStringReq(name));
}

int64_t ParameterTree::GetInt64Req(const std::string& name) const {
    return Converter::ToInt64(GetStringReq(name));
}

double ParameterTree::GetDoubleReq(const std::string& name) const {
    return Converter::ToDouble(GetStringReq(name));
}

float ParameterTree::GetFloatReq(const std::string& name) const {
    return Converter::ToFloat(GetStringReq(name));
}

bool ParameterTree::GetBoolReq(const std::string& name) const {
    return Converter::ToBool(GetStringReq(name));
}

// Returns the text of the first child named `name`; reports an error
// (including a dump of this tree) when there is no such child.
std::string ParameterTree::GetStringReq(const std::string& name) const {
    const std::string * value = GetParamInternal(name);
    if (value == nullptr) {
        LOG_ERROR_AND_THROW("Required parameter <%s> not found in ParameterTree:\n%s", name.c_str(), ToString().c_str());
    }
    return (*value);
}

// Like GetFileListOptional, but an empty result is reported as an error.
std::vector<std::string> ParameterTree::GetFileListReq(const std::string& name) const {
    std::vector<std::string> output = GetFileListOptional(name);
    if (output.empty()) {
        LOG_ERROR_AND_THROW("No files were found for parameter: %s", name.c_str());
    }
    return output;
}

// Splits the text of child `name` on ';'. A missing child or empty text
// yields an empty list.
std::vector<std::string> ParameterTree::GetFileListOptional(const std::string& name) const {
    const std::string * value = GetParamInternal(name);
    if (value == nullptr || value->empty()) {
        return std::vector<std::string>();
    }
    return StringUtils::Split(*value, ";");
}

// Splits the text of the required child `name` on `sep`.
std::vector<std::string> ParameterTree::GetStringListReq(const std::string& name, const std::string& sep) const {
    return StringUtils::Split(GetStringReq(name), sep);
}

// Splits the text of child `name` on `sep`; a missing child is treated as
// empty text.
std::vector<std::string> ParameterTree::GetStringListOptional(const std::string& name, const std::string& sep) const {
    return StringUtils::Split(GetStringOr(name, ""), sep);
}

// Returns the first child named `name`; reports an error if there is none.
std::shared_ptr<ParameterTree> ParameterTree::GetChildReq(const std::string& name) const {
    for (const auto& child : m_children) {
        if (child->Name() == name) {
            return child;
        }
    }
    LOG_ERROR_AND_THROW("Unable to find child ParameterTree with name '%s'", name.c_str());
    return nullptr; // never happens
}

// Returns the first child named `name`, or a freshly created default tree
// when there is none.
std::shared_ptr<ParameterTree> ParameterTree::GetChildOrEmpty(const std::string& name) const {
    for (const auto& child : m_children) {
        if (child->Name() == name) {
            return child;
        }
    }
    return std::make_shared<ParameterTree>();
}

// cast current void pointer to T pointer and move forward by num elements
template <typename T>
const T* get(const void*& current, size_t num = 1) {
    const T* ptr = (const T*)current;
    current = (const T*)current + num;
    return ptr;
}

// Reads this node from the binary buffer at `current`:
//   int32 name length, name bytes, int32 text length, text bytes,
//   int32 child count, followed by that many recursively encoded children.
// `current` is advanced past everything consumed. Negative lengths or counts
// are reported as errors. The end of the buffer is not known here, so sizes
// cannot be checked against it.
void ParameterTree::ReadBinary(const void*& current) {
    const int32_t nameLength = *get<int32_t>(current);
    if (nameLength < 0) {
        LOG_ERROR_AND_THROW("Invalid name length %d in binary ParameterTree", nameLength);
    }
    const char * nameBytes = get<char>(current, nameLength);
    m_name.assign(nameBytes, nameBytes + nameLength);

    const int32_t textLength = *get<int32_t>(current);
    if (textLength < 0) {
        LOG_ERROR_AND_THROW("Invalid text length %d in binary ParameterTree", textLength);
    }
    const char * textBytes = get<char>(current, textLength);
    m_text.assign(textBytes, textBytes + textLength);

    const int32_t num_children = *get<int32_t>(current);
    if (num_children < 0) {
        LOG_ERROR_AND_THROW("Invalid child count %d in binary ParameterTree", num_children);
    }
    m_children.clear();
    m_children.reserve(num_children);
    for (int32_t i = 0; i < num_children; i++) {
        std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>();
        child->ReadBinary(current);
        m_children.push_back(child);
    }
}

// Returns all direct children named `name`, in their stored order.
std::vector< std::shared_ptr<ParameterTree> > ParameterTree::GetChildren(const std::string& name) const {
    std::vector< std::shared_ptr<ParameterTree> > children;
    // iterate by reference: no shared_ptr copy for children that do not match
    for (const std::shared_ptr<ParameterTree>& child : m_children) {
        if (child->Name() == name) {
            children.push_back(child);
        }
    }
    return children;
}

// Appends a new child `name` with the given text, even if a child with the
// same name already exists.
void ParameterTree::AddParam(const std::string& name, const std::string& text) {
    std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name);
    child->SetText(text);
    m_children.push_back(child);
}

// Sets the text of the first child named `name`, appending a new child when
// there is none.
void ParameterTree::SetParam(const std::string& name, const std::string& text) {
    for (const auto& child : m_children) {
        if (child->Name() == name) {
            child->SetText(text);
            return;
        }
    }
    AddParam(name, text);
}

// Appends `child` to this node's children.
void ParameterTree::AddChild(std::shared_ptr<ParameterTree> child) {
    m_children.push_back(child);
}

// True if a direct child named `name` exists.
bool ParameterTree::HasParam(const std::string& name) const {
    return GetParamInternal(name) != nullptr;
}

// True if a direct child named `name` exists.
bool ParameterTree::HasChild(const std::string& name) const {
    for (const auto& child : m_children) {
        if (child->Name() == name) {
            return true;
        }
    }
    return false;
}

// Renders the tree as indented, XML-like text.
std::string ParameterTree::ToString() const {
    std::ostringstream ss;
    ToStringInternal(0, ss);
    return ss.str();
}

// Returns a pointer to the text of the first child named `name`, or nullptr.
// The pointer refers to storage owned by that child.
const std::string * ParameterTree::GetParamInternal(const std::string& name) const {
    for (const auto& child : m_children) {
        if (child->Name() == name) {
            return &(child->Text());
        }
    }
    return nullptr;
}

// Records a registration; registering the same name twice is an error.
void ParameterTree::RegisterItemInternal(const std::string& name, ParameterType type, void * param) {
    if (m_registered_param_names.count(name) > 0) {
        LOG_ERROR_AND_THROW("Unable to register duplicate parameter name: '%s'", name.c_str());
    }
    m_registered_params.push_back(RegisteredParam(name, type, param));
    m_registered_param_names.insert(name);
}

// Writes this node at the given depth: nodes with children are printed as an
// opening tag, the children one level deeper, and a closing tag; leaf nodes
// are printed as <name>text</name> on one line.
void ParameterTree::ToStringInternal(int32_t depth, std::ostream& ss) const {
    for (int32_t i = 0; i < 2*depth; i++) {
        ss << " ";
    }
    ss << "<" << m_name << ">";
    if (m_children.empty()) {
        ss << m_text << "</" << m_name << ">\n";
        return;
    }

    ss << "\n";
    for (const std::shared_ptr<ParameterTree>& child : m_children) {
        child->ToStringInternal(depth+1, ss);
    }
    for (int32_t i = 0; i < 2 * depth; i++) {
        ss << " ";
    }
    ss << "</" << m_name << ">\n";
}

// Deep copy of name, text and children. Registered parameters are not copied.
std::shared_ptr<ParameterTree> ParameterTree::Clone() const {
    std::shared_ptr<ParameterTree> node = std::make_shared<ParameterTree>(m_name);
    node->m_text = m_text;
    for (const auto& child : m_children) {
        node->m_children.push_back(child->Clone());
    }
    return node;
}

// Merges `other` into this node: name and text are taken from `other`;
// children missing here are cloned in; for children present on both sides the
// text is overwritten when both texts are non-empty, otherwise the two
// children are merged recursively.
void ParameterTree::Merge(const ParameterTree& other) {
    m_name = other.m_name;
    m_text = other.m_text;
    for (const auto& other_child : other.m_children) {
        if (!HasChild(other_child->Name())) {
            m_children.push_back(other_child->Clone());
            continue;
        }

        auto my_child = GetChildReq(other_child->Name());
        if (other_child->Text() != "" && my_child->Text() != "") {
            my_child->SetText(other_child->Text());
        }
        else {
            my_child->Merge(*other_child);
        }
    }
}

// Rewrites m_text, substituting each complete $$name$$ marker with
// vars[name]. Unknown names are either an error or left in place, depending
// on `error_on_unknown_vars`. An unterminated marker ends the scan and the
// remaining text is copied unchanged. Recurses into all children afterwards.
void ParameterTree::ReplaceVariablesInternal(
    const std::unordered_map<std::string, std::string>& vars,
    bool error_on_unknown_vars)
{
    std::ostringstream ss;
    std::size_t offset = 0;
    for (;;) {
        const std::size_t s_pos = m_text.find("$$", offset);
        if (s_pos == std::string::npos) {
            break;
        }
        const std::size_t e_pos = m_text.find("$$", s_pos + 2);
        if (e_pos == std::string::npos) {
            break;
        }

        // literal text between the previous marker and this one
        if (offset != s_pos) {
            ss << m_text.substr(offset, s_pos-offset);
        }

        const std::string var_name = m_text.substr(s_pos+2, e_pos - (s_pos+2));
        const auto it = vars.find(var_name);
        if (it != vars.end()) {
            ss << it->second;
        }
        else if (error_on_unknown_vars) {
            LOG_ERROR_AND_THROW("The variable $$%s$$ was not found", var_name.c_str());
        }
        else {
            ss << "$$" << var_name << "$$";
        }
        offset = e_pos + 2;
    }
    ss << m_text.substr(offset);

    m_text = ss.str();

    for (auto& child : m_children) {
        child->ReplaceVariablesInternal(vars, error_on_unknown_vars);
    }
}

} // namespace quicksand
|
||||
|
185
src/microsoft/shortlist/utils/ParameterTree.h
Normal file
185
src/microsoft/shortlist/utils/ParameterTree.h
Normal file
@ -0,0 +1,185 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
|
||||
#include "microsoft/shortlist/utils/StringUtils.h"
|
||||
|
||||
namespace quicksand {
|
||||
|
||||
// A tree of named string parameters.
//
// Every node has a name, a text payload and an ordered list of child nodes.
// Typed destinations can be registered with the Register*() methods; the
// member functions declared here are implemented out of line.
//
// NOTE(review): judging by the names, the Get*Req() accessors presumably fail
// when the parameter is missing while the Get*Or() accessors return the
// supplied default -- confirm against ParameterTree.cpp.
class ParameterTree {
private:
    // Type tag stored alongside a registered destination pointer.
    enum ParameterType {
        PARAM_TYPE_INT32,
        PARAM_TYPE_INT64,
        PARAM_TYPE_UINT64,
        PARAM_TYPE_FLOAT,
        PARAM_TYPE_DOUBLE,
        PARAM_TYPE_BOOL,
        PARAM_TYPE_STRING
    };

    // Associates a parameter name with a type tag and a type-erased pointer.
    class RegisteredParam {
    private:
        std::string m_name;
        // Default member initializers keep a default-constructed instance in a
        // well-defined state (previously m_type and m_data were indeterminate).
        ParameterType m_type = PARAM_TYPE_INT32;
        void * m_data = nullptr;

    public:
        RegisteredParam() {}

        RegisteredParam(const std::string& name,
                        ParameterType type,
                        void * data)
            : m_name(name), m_type(type), m_data(data)
        {
        }

        const std::string& Name() const {return m_name;}
        const ParameterType& Type() const {return m_type;}
        void * Data() const {return m_data;}
    };

    static std::shared_ptr<ParameterTree> m_empty_tree;

    // Name of this node.
    std::string m_name;

    // Text payload of this node.
    std::string m_text;

    // Child nodes, in insertion order.
    std::vector< std::shared_ptr<ParameterTree> > m_children;

    std::unordered_set<std::string> m_registered_param_names;

    std::vector<RegisteredParam> m_registered_params;

public:
    ParameterTree();

    ParameterTree(const std::string& name);

    ~ParameterTree();

    inline const std::string& Text() const { return m_text; }
    inline void SetText(const std::string& text) { m_text = text; }

    inline const std::string& Name() const { return m_name; }
    inline void SetName(const std::string& name) { m_name = name; }

    void Clear();

    // NOTE(review): ReplaceVariablesInternal rewrites "$$name$$" placeholders
    // in the node text using 'vars' and recurses into the children; this
    // public entry point presumably forwards to it -- confirm.
    void ReplaceVariables(
        const std::unordered_map<std::string, std::string>& vars,
        bool error_on_unknown_vars = true);

    void RegisterInt32(const std::string& name, int32_t * param);

    void RegisterInt64(const std::string& name, int64_t * param);

    void RegisterFloat(const std::string& name, float * param);

    void RegisterDouble(const std::string& name, double * param);

    void RegisterBool(const std::string& name, bool * param);

    void RegisterString(const std::string& name, std::string * param);

    // 'current' is taken by reference; NOTE(review): presumably it is advanced
    // past the bytes that were consumed -- confirm in the implementation.
    static std::shared_ptr<ParameterTree> FromBinaryReader(const void*& current);

    void SetRegisteredParams();

    int32_t GetInt32Req(const std::string& name) const;

    int64_t GetInt64Req(const std::string& name) const;

    uint64_t GetUInt64Req(const std::string& name) const;

    double GetDoubleReq(const std::string& name) const;

    float GetFloatReq(const std::string& name) const;

    std::string GetStringReq(const std::string& name) const;

    bool GetBoolReq(const std::string& name) const;

    int32_t GetInt32Or(const std::string& name, int32_t defaultValue) const;

    int64_t GetInt64Or(const std::string& name, int64_t defaultValue) const;

    uint64_t GetUInt64Or(const std::string& name, uint64_t defaultValue) const;

    std::string GetStringOr(const std::string& name, const std::string& defaultValue) const;

    double GetDoubleOr(const std::string& name, double defaultValue) const;

    float GetFloatOr(const std::string& name, float defaultValue) const;

    bool GetBoolOr(const std::string& name, bool defaultValue) const;

    std::vector<std::string> GetFileListReq(const std::string& name) const;

    std::vector<std::string> GetFileListOptional(const std::string& name) const;

    std::vector<std::string> GetStringListReq(const std::string& name, const std::string& sep = " ") const;

    std::vector<std::string> GetStringListOptional(const std::string& name, const std::string& sep = " ") const;

    std::shared_ptr<ParameterTree> GetChildReq(const std::string& name) const;

    std::shared_ptr<ParameterTree> GetChildOrEmpty(const std::string& name) const;

    std::vector< std::shared_ptr<ParameterTree> > GetChildren(const std::string& name) const;

    // All direct children of this node.
    inline const std::vector< std::shared_ptr<ParameterTree> >& GetChildren() const { return m_children; }

    void ReadBinary(const void*& current);

    void AddParam(const std::string& name, const std::string& text);

    // Converts 'obj' with StringUtils::ToString and forwards to the string overload.
    template <typename T>
    void AddParam(const std::string& name, const T& obj);

    void SetParam(const std::string& name, const std::string& text);

    // Converts 'obj' with StringUtils::ToString and forwards to the string overload.
    template <typename T>
    void SetParam(const std::string& name, const T& obj);

    void AddChild(std::shared_ptr<ParameterTree> child);

    std::string ToString() const;

    bool HasChild(const std::string& name) const;

    bool HasParam(const std::string& name) const;

    std::shared_ptr<ParameterTree> Clone() const;

    void Merge(const ParameterTree& other);

private:
    // Replaces "$$name$$" placeholders in this node's text and then recurses
    // into every child with the same arguments.
    void ReplaceVariablesInternal(
        const std::unordered_map<std::string, std::string>& vars,
        bool error_on_unknown_vars);

    void RegisterItemInternal(const std::string& name, ParameterType type, void * param);

    const std::string * GetParamInternal(const std::string& name) const;

    void ToStringInternal(int32_t depth, std::ostream& ss) const;
};
|
||||
|
||||
template <typename T>
|
||||
void ParameterTree::AddParam(const std::string& name, const T& obj) {
|
||||
AddParam(name, StringUtils::ToString(obj));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ParameterTree::SetParam(const std::string& name, const T& obj) {
|
||||
SetParam(name, StringUtils::ToString(obj));
|
||||
}
|
||||
|
||||
} // namespace quicksand
|
16
src/microsoft/shortlist/utils/PrintTypes.h
Normal file
16
src/microsoft/shortlist/utils/PrintTypes.h
Normal file
@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include <inttypes.h>

// printf conversion specifiers (without the leading '%') for fixed-width
// integers. Usage: printf("%" PI64 "\n", value);
#if defined(QUICKSAND_WINDOWS_BUILD)
// Windows builds spell the specifiers out explicitly.
#define PI32 "d"
#define PI64 "lld"
#define PU32 "u"
#define PU64 "llu"
#else
// Everywhere else defer to the <inttypes.h> macros.
#define PI32 PRId32
#define PI64 PRId64
#define PU32 PRIu32
#define PU64 PRIu64
#endif
|
||||
|
338
src/microsoft/shortlist/utils/StringUtils.cpp
Normal file
338
src/microsoft/shortlist/utils/StringUtils.cpp
Normal file
@ -0,0 +1,338 @@
|
||||
#include "microsoft/shortlist/utils/StringUtils.h"
|
||||
|
||||
#include <stdio.h>

#include <algorithm>
#include <cctype>
#include <string>
|
||||
|
||||
namespace quicksand {
|
||||
|
||||
#include "microsoft/shortlist/logging/LoggerMacros.h"
|
||||
|
||||
std::string StringUtils::VarArgsToString(const char * format, va_list args) {
|
||||
if (format == nullptr) {
|
||||
LOG_ERROR_AND_THROW("'format' cannot be null in StringUtils::VarArgsToString");
|
||||
}
|
||||
|
||||
std::string output;
|
||||
// Most of the time the stack buffer (5000 chars) will be sufficient.
|
||||
// In cases where this is insufficient, dynamically allocate an appropriately sized buffer
|
||||
char buffer[5000];
|
||||
#ifdef QUICKSAND_WINDOWS_BUILD
|
||||
va_list copy;
|
||||
va_copy(copy, args);
|
||||
int ret = vsnprintf_s(buffer, sizeof(buffer), _TRUNCATE, format, copy);
|
||||
va_end(copy);
|
||||
if (ret >= 0) {
|
||||
output = std::string(buffer, buffer + ret);
|
||||
}
|
||||
else {
|
||||
va_list copy2;
|
||||
va_copy(copy2, args);
|
||||
int needed_size = _vscprintf(format, copy2);
|
||||
va_end(copy2);
|
||||
|
||||
if (needed_size < 0) {
|
||||
LOG_ERROR_AND_THROW("A call to vsnprintf_s() failed. This should never happen");
|
||||
}
|
||||
char * dynamic_buffer = new char[needed_size+1];
|
||||
int ret2 = vsnprintf_s(dynamic_buffer, needed_size+1, _TRUNCATE, format, args);
|
||||
if (ret2 >= 0) {
|
||||
output = std::string(dynamic_buffer, dynamic_buffer + ret2);
|
||||
delete[] dynamic_buffer;
|
||||
}
|
||||
else {
|
||||
output = "";
|
||||
delete[] dynamic_buffer;
|
||||
LOG_ERROR_AND_THROW("A call to vsnprintf_s() failed. This should never happen, "
|
||||
"since we made a call to _vscprintf() to check the dynamic buffer size. The call to _vscprintf() "
|
||||
"returned %d bytes, but apparently that was not enough. This would imply a bug in MSVC's vsnprintf_s implementation.", needed_size);
|
||||
}
|
||||
}
|
||||
#else
|
||||
va_list copy;
|
||||
va_copy(copy, args);
|
||||
int needed_size = vsnprintf(buffer, sizeof(buffer), format, copy);
|
||||
va_end(copy);
|
||||
if (needed_size < (int)sizeof(buffer)) {
|
||||
output = std::string(buffer, buffer + needed_size);
|
||||
}
|
||||
else {
|
||||
char * dynamic_buffer = new char[needed_size+1];
|
||||
int ret = vsnprintf(dynamic_buffer, needed_size + 1, format, args);
|
||||
if (ret >= 0 && ret < needed_size + 1) {
|
||||
output = std::string(dynamic_buffer);
|
||||
delete[] dynamic_buffer;
|
||||
}
|
||||
else {
|
||||
output = "";
|
||||
delete[] dynamic_buffer;
|
||||
LOG_ERROR_AND_THROW("A call to vsnprintf() failed. Return value: %d.",
|
||||
ret);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return output;
|
||||
}
|
||||
|
||||
// Splits 'input' at line terminators. "\r", "\n" and "\r\n" each count as a
// single terminator. Empty intermediate lines are kept, but no empty trailing
// line is produced when the input ends with a terminator.
std::vector<std::string> StringUtils::SplitIntoLines(const std::string& input) {
    std::vector<std::string> lines;
    const std::size_t len = input.size();
    std::size_t line_begin = 0;
    std::size_t pos = 0;
    while (pos < len) {
        const char ch = input[pos];
        if (ch != '\r' && ch != '\n') {
            ++pos;
            continue;
        }
        lines.push_back(input.substr(line_begin, pos - line_begin));
        // Treat CRLF as one terminator by swallowing the '\n' as well.
        if (ch == '\r' && pos + 1 < len && input[pos + 1] == '\n') {
            ++pos;
        }
        ++pos;
        line_begin = pos;
    }
    // Whatever follows the last terminator is a final, unterminated line.
    if (line_begin < len) {
        lines.push_back(input.substr(line_begin));
    }
    return lines;
}
|
||||
|
||||
// Returns true when 'str' begins with 'prefix' (an empty prefix always matches).
bool StringUtils::StartsWith(const std::string& str, const std::string& prefix) {
    const std::size_t prefix_len = prefix.size();
    return str.size() >= prefix_len
        && str.compare(0, prefix_len, prefix) == 0;
}
|
||||
|
||||
// Returns true when 'str' ends with 'suffix' (an empty suffix always matches).
bool StringUtils::EndsWith(const std::string& str, const std::string& suffix) {
    const std::size_t suffix_len = suffix.size();
    return str.size() >= suffix_len
        && str.compare(str.size() - suffix_len, suffix_len, suffix) == 0;
}
|
||||
|
||||
std::vector<std::string> StringUtils::SplitFileList(const std::string& input) {
|
||||
std::vector<std::string> output;
|
||||
for (const std::string& s : SplitIntoLines(input)) {
|
||||
for (const std::string& t : Split(s, ";")) {
|
||||
std::string f = CleanupWhitespace(t);
|
||||
output.push_back(f);
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
// Splits 'input' at every occurrence of the character 'splitter'.
// An empty input yields an empty vector; otherwise the result has one more
// element than there are separators (adjacent separators give empty pieces).
std::vector<std::string> StringUtils::Split(const std::string& input, char splitter) {
    std::vector<std::string> pieces;
    if (input.empty()) {
        return pieces;
    }
    std::size_t piece_begin = 0;
    for (std::size_t hit = input.find(splitter);
         hit != std::string::npos;
         hit = input.find(splitter, piece_begin)) {
        pieces.push_back(input.substr(piece_begin, hit - piece_begin));
        piece_begin = hit + 1;
    }
    // Remainder after the last separator (possibly empty).
    pieces.push_back(input.substr(piece_begin));
    return pieces;
}
|
||||
|
||||
// Splits 'input' at every occurrence of the string 'splitter'.
//
// An empty input yields an empty vector; otherwise the result has one more
// element than there are separators (adjacent separators give empty pieces).
// An empty 'splitter' never matches, so the whole input is returned as a
// single element. (Previously an empty splitter matched at every position
// without advancing, which looped forever.)
std::vector<std::string> StringUtils::Split(const std::string& input, const std::string& splitter) {
    std::vector<std::string> output;
    if (input.size() == 0) {
        return output;
    }
    if (splitter.empty()) {
        output.push_back(input);
        return output;
    }
    std::size_t pos = 0;
    while (true) {
        std::size_t next_pos = input.find(splitter, pos);
        if (next_pos == std::string::npos) {
            // No further separator: the rest of the input is the last piece.
            output.push_back(input.substr(pos));
            break;
        }
        output.push_back(input.substr(pos, next_pos - pos));
        pos = next_pos + splitter.size();
    }
    return output;
}
|
||||
|
||||
// Joins 'length' unsigned bytes with 'joiner' between them. Each byte is
// written as its numeric value rather than as a character.
std::string StringUtils::Join(const std::string& joiner, const uint8_t * items, int32_t length) {
    std::ostringstream joined;
    for (int32_t idx = 0; idx < length; ++idx) {
        if (idx > 0) {
            joined << joiner;
        }
        joined << static_cast<int32_t>(items[idx]);
    }
    return joined.str();
}
|
||||
|
||||
// Joins 'length' signed bytes with 'joiner' between them. Each byte is
// written as its numeric value rather than as a character.
std::string StringUtils::Join(const std::string& joiner, const int8_t * items, int32_t length) {
    std::ostringstream joined;
    for (int32_t idx = 0; idx < length; ++idx) {
        if (idx > 0) {
            joined << joiner;
        }
        joined << static_cast<int32_t>(items[idx]);
    }
    return joined.str();
}
|
||||
|
||||
// printf-style formatting into a std::string; a variadic front end for
// VarArgsToString.
std::string StringUtils::PrintString(const char * format, ...) {
    va_list arg_list;
    va_start(arg_list, format);
    std::string formatted = StringUtils::VarArgsToString(format, arg_list);
    va_end(arg_list);
    return formatted;
}
|
||||
|
||||
std::vector<std::string> StringUtils::WhitespaceTokenize(const std::string& input) {
|
||||
std::vector<std::string> output;
|
||||
if (input.size() == 0) {
|
||||
return output;
|
||||
}
|
||||
std::size_t size = input.size();
|
||||
std::size_t start = 0;
|
||||
std::size_t end = size;
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
char c = input[i];
|
||||
if (IsWhitespace(c)) {
|
||||
start++;
|
||||
}
|
||||
else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
char c = input[size-1-i];
|
||||
if (IsWhitespace(c)) {
|
||||
end--;
|
||||
}
|
||||
else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (end <= start) {
|
||||
return output;
|
||||
}
|
||||
bool prev_is_whitespace = false;
|
||||
std::size_t token_start = start;
|
||||
for (std::size_t i = start; i < end; i++) {
|
||||
char c = input[i];
|
||||
if (IsWhitespace(c)) {
|
||||
if (!prev_is_whitespace) {
|
||||
output.push_back(std::string(input.begin() + token_start, input.begin() + i));
|
||||
}
|
||||
prev_is_whitespace = true;
|
||||
token_start = i+1;
|
||||
}
|
||||
else {
|
||||
prev_is_whitespace = false;
|
||||
}
|
||||
}
|
||||
output.push_back(std::string(input.begin() + token_start, input.begin() + end));
|
||||
return output;
|
||||
}
|
||||
|
||||
std::string StringUtils::CleanupWhitespace(const std::string& input) {
|
||||
if (input.size() == 0) {
|
||||
return std::string("");
|
||||
}
|
||||
std::size_t size = input.size();
|
||||
std::size_t start = 0;
|
||||
std::size_t end = size;
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
char c = input[i];
|
||||
if (IsWhitespace(c)) {
|
||||
start++;
|
||||
}
|
||||
else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
char c = input[size-1-i];
|
||||
if (IsWhitespace(c)) {
|
||||
end--;
|
||||
}
|
||||
else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (end <= start) {
|
||||
return std::string("");
|
||||
}
|
||||
std::ostringstream ss;
|
||||
bool prev_is_whitespace = false;
|
||||
for (std::size_t i = start; i < end; i++) {
|
||||
char c = input[i];
|
||||
if (IsWhitespace(c)) {
|
||||
if (!prev_is_whitespace) {
|
||||
ss << ' ';
|
||||
}
|
||||
prev_is_whitespace = true;
|
||||
}
|
||||
else {
|
||||
ss << c;
|
||||
prev_is_whitespace = false;
|
||||
}
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
// Returns a copy of 'str' in which the five XML special characters are
// replaced by their predefined entity references:
//   &  ->  &amp;     "  ->  &quot;     '  ->  &apos;
//   <  ->  &lt;      >  ->  &gt;
// All other characters are copied unchanged.
std::string StringUtils::XmlEscape(const std::string& str) {
    std::ostringstream ss;
    for (std::size_t i = 0; i < str.size(); i++) {
        const char c = str[i];
        switch (c) {
        case '&':
            ss << "&amp;";
            break;
        case '"':
            ss << "&quot;";
            break;
        case '\'':
            ss << "&apos;";
            break;
        case '<':
            ss << "&lt;";
            break;
        case '>':
            ss << "&gt;";
            break;
        default:
            ss << c;
            break;
        }
    }
    return ss.str();
}
|
||||
|
||||
// Identity conversion: a string converts to a copy of itself.
std::string StringUtils::ToString(const std::string& str) {
    std::string copy(str);
    return copy;
}
|
||||
|
||||
// Renders a bool as the literal text "true" or "false".
std::string StringUtils::ToString(bool obj) {
    if (obj) {
        return "true";
    }
    return "false";
}
|
||||
|
||||
// Returns a copy of 'str' with every character passed through std::toupper.
// Intended for ASCII text such as file names, not for language data.
std::string StringUtils::ToUpper(const std::string& str) {
    std::string output;
    output.reserve(str.size());
    for (char c : str) {
        // std::toupper requires a value representable as unsigned char (or
        // EOF); passing a negative plain char is undefined behaviour.
        output.push_back(static_cast<char>(std::toupper(static_cast<unsigned char>(c))));
    }
    return output;
}
|
||||
|
||||
// Returns a copy of 'str' with every character passed through std::tolower.
// Intended for ASCII text such as file names, not for language data.
// (The previous implementation copied the characters without converting them.)
std::string StringUtils::ToLower(const std::string& str) {
    std::string output;
    output.reserve(str.size());
    for (char c : str) {
        // std::tolower requires a value representable as unsigned char (or
        // EOF); passing a negative plain char is undefined behaviour.
        output.push_back(static_cast<char>(std::tolower(static_cast<unsigned char>(c))));
    }
    return output;
}
|
||||
|
||||
} // namespace quicksand
|
98
src/microsoft/shortlist/utils/StringUtils.h
Normal file
98
src/microsoft/shortlist/utils/StringUtils.h
Normal file
@ -0,0 +1,98 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <stdarg.h>
|
||||
#include <vector>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "microsoft/shortlist/utils/PrintTypes.h"
|
||||
|
||||
namespace quicksand {
|
||||
|
||||
// Collection of static string helpers (joining, splitting, printf-style
// formatting, whitespace handling and simple conversions).
class StringUtils {
public:
    // Streams every element of a container, separated by 'joiner'.
    template <typename T>
    static std::string Join(const std::string& joiner, const T& items);

    // Streams the first 'length' elements of an array, separated by 'joiner'.
    template <typename T>
    static std::string Join(const std::string& joiner, const T * items, int32_t length);

    // Byte overloads: elements are written as numbers, not as characters.
    static std::string Join(const std::string& joiner, const uint8_t * items, int32_t length);
    static std::string Join(const std::string& joiner, const int8_t * items, int32_t length);

    // Split at a single character / at a separator string.
    // An empty input yields an empty vector.
    static std::vector<std::string> Split(const std::string& input, char splitter);
    static std::vector<std::string> Split(const std::string& input, const std::string& splitter);

    // Splits into lines, then on ";", and whitespace-cleans every piece.
    static std::vector<std::string> SplitFileList(const std::string& input);

    // printf-style formatting into a std::string.
    static std::string PrintString(const char * format, ...);
    static std::string VarArgsToString(const char * format, va_list args);

    // Maximal runs of non-whitespace characters, in order.
    static std::vector<std::string> WhitespaceTokenize(const std::string& input);

    // Trims the ends and collapses interior whitespace runs to one space.
    static std::string CleanupWhitespace(const std::string& input);

    static std::string ToString(const std::string& str);

    // Yields "true" or "false".
    static std::string ToString(bool obj);

    // Generic conversion through operator<<.
    template <typename T>
    static std::string ToString(const T& obj);

    // XML-escaping helper; see the definition for the exact replacements.
    static std::string XmlEscape(const std::string& str);

    // "\r", "\n" and "\r\n" each end a line; no empty trailing line is returned.
    static std::vector<std::string> SplitIntoLines(const std::string& input);

    static bool StartsWith(const std::string& str, const std::string& prefix);

    static bool EndsWith(const std::string& str, const std::string& suffix);

    // True for space, tab, line feed and carriage return only.
    inline static bool IsWhitespace(char c) {
        switch (c) {
        case ' ':
        case '\t':
        case '\n':
        case '\r':
            return true;
        default:
            return false;
        }
    }

    // This should only be used for ASCII, e.g., filenames, NOT for language data
    static std::string ToLower(const std::string& str);

    // This should only be used for ASCII, e.g., filenames, NOT for language data
    static std::string ToUpper(const std::string& str);
};
|
||||
|
||||
template <typename T>
|
||||
std::string StringUtils::Join(const std::string& joiner, const T& items) {
|
||||
std::ostringstream ss;
|
||||
bool first = true;
|
||||
for (auto it = items.begin(); it != items.end(); it++) {
|
||||
if (!first) {
|
||||
ss << joiner;
|
||||
}
|
||||
ss << (*it);
|
||||
first = false;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string StringUtils::Join(const std::string& joiner, const T * items, int32_t length) {
|
||||
std::ostringstream ss;
|
||||
for (int32_t i = 0; i < length; i++) {
|
||||
if (i != 0) {
|
||||
ss << joiner;
|
||||
}
|
||||
ss << items[i];
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string StringUtils::ToString(const T& obj) {
|
||||
std::ostringstream ss;
|
||||
ss << obj;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace quicksand
|
@ -60,8 +60,7 @@ public:
|
||||
auto srcVocab = corpus_->getVocabs()[0];
|
||||
|
||||
if(options_->hasAndNotEmpty("shortlist"))
|
||||
shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
|
||||
options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
|
||||
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
|
||||
|
||||
auto devices = Config::getDevices(options_);
|
||||
numDevices_ = devices.size();
|
||||
|
Loading…
Reference in New Issue
Block a user