merges internal changes to external GitHub repos

Taku Kudo 2023-12-23 07:20:11 +00:00
parent 022f8c3fed
commit 6b32c01286
12 changed files with 178 additions and 47 deletions

View File

@ -194,9 +194,9 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
EncodeResult output;
for (int index = 0; index != -1; index = symbols[index].next) {
CHECK_GE(index, 0);
CHECK_LT(index, static_cast<int>(symbols.size()));
resegment(symbols[index].piece, &output);
if (index >= 0 && index < static_cast<int>(symbols.size())) {
resegment(symbols[index].piece, &output);
}
}
return output;

View File

@ -53,12 +53,6 @@ typedef uint64_t uint64;
static constexpr uint32 kUnicodeError = 0xFFFD;
#if defined(OS_WIN) && defined(UNICODE) && defined(_UNICODE)
#define WPATH(path) (::sentencepiece::win32::Utf8ToWide(path).c_str())
#else
#define WPATH(path) (path)
#endif
template <typename T, size_t N>
char (&ArraySizeHelper(T (&array)[N]))[N];

View File

@ -12,16 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include <iostream>
#include "filesystem.h"
#include <fstream>
#include <iostream>
#include <memory>
#include "third_party/absl/memory/memory.h"
#include "util.h"
#if defined(OS_WIN) && defined(UNICODE) && defined(_UNICODE)
#define WPATH(path) (::sentencepiece::win32::Utf8ToWide(path).c_str())
#define WPATH(path) (::sentencepiece::util::Utf8ToWide(path).c_str())
#else
#define WPATH(path) (path)
#define WPATH(path) (path.data())
#endif
namespace sentencepiece {
@ -32,7 +35,7 @@ class PosixReadableFile : public ReadableFile {
PosixReadableFile(absl::string_view filename, bool is_binary = false)
: is_(filename.empty()
? &std::cin
: new std::ifstream(WPATH(filename.data()),
: new std::ifstream(WPATH(filename),
is_binary ? std::ios::binary | std::ios::in
: std::ios::in)) {
if (!*is_)
@ -70,7 +73,7 @@ class PosixWritableFile : public WritableFile {
PosixWritableFile(absl::string_view filename, bool is_binary = false)
: os_(filename.empty()
? &std::cout
: new std::ofstream(WPATH(filename.data()),
: new std::ofstream(WPATH(filename),
is_binary ? std::ios::binary | std::ios::out
: std::ios::out)) {
if (!*os_)

View File

@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "model_interface.h"
#include <algorithm>
#include "model_interface.h"
#include "sentencepiece_model.pb.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/str_format.h"
@ -68,6 +69,23 @@ void ModelInterface::InitializePieces() {
std::set<absl::string_view> user_defined_symbols;
std::vector<bool> byte_found(256, false);
int pieces_size = 0;
int reserved_id_map_size = 0;
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
const bool is_normal_piece =
(sp.type() == ModelProto::SentencePiece::NORMAL ||
sp.type() == ModelProto::SentencePiece::USER_DEFINED ||
sp.type() == ModelProto::SentencePiece::UNUSED);
if (is_normal_piece) {
++pieces_size;
} else {
++reserved_id_map_size;
}
}
pieces_.reserve(pieces_size);
reserved_id_map_.reserve(reserved_id_map_size);
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
if (sp.piece().empty()) {

View File

@ -315,8 +315,11 @@ PrefixMatcher::PrefixMatcher(const std::set<absl::string_view> &dic) {
key.reserve(dic.size());
for (const auto &it : dic) key.push_back(it.data());
trie_ = absl::make_unique<Darts::DoubleArray>();
CHECK_EQ(0, trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
nullptr));
if (trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
nullptr) != 0) {
LOG(ERROR) << "Failed to build the TRIE for PrefixMatcher";
trie_.reset();
}
}
int PrefixMatcher::PrefixMatch(absl::string_view w, bool *found) const {

View File

@ -14,9 +14,15 @@
#include "sentencepiece_processor.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>
#include "common.h"
#include "filesystem.h"
@ -409,7 +415,7 @@ util::Status SentencePieceProcessor::Decode(
SentencePieceText spt;
RETURN_IF_ERROR(Decode(pieces, &spt));
*detokenized = std::move(spt.text());
*detokenized = std::move(*spt.mutable_text());
return util::OkStatus();
}
@ -420,7 +426,7 @@ util::Status SentencePieceProcessor::Decode(const std::vector<int> &ids,
SentencePieceText spt;
RETURN_IF_ERROR(Decode(ids, &spt));
*detokenized = std::move(spt.text());
*detokenized = std::move(*spt.mutable_text());
return util::OkStatus();
}
@ -623,10 +629,10 @@ util::Status SentencePieceProcessor::PopulateSentencePieceText(
CHECK_EQ_OR_RETURN(consumed, normalized.size())
<< "all normalized characters are not consumed.";
RETURN_IF_ERROR(ApplyExtraOptions(encode_extra_options_, spt));
spt->set_text(input.data(), input.size());
RETURN_IF_ERROR(ApplyExtraOptions(encode_extra_options_, spt));
return util::OkStatus();
} // namespace sentencepiece
@ -695,10 +701,17 @@ util::Status SentencePieceProcessor::SampleEncode(
const auto nbests = model_->NBestEncode(normalized, nbest_size);
CHECK_OR_RETURN(!nbests.empty()) << "NBestEncode returns empty result.";
std::vector<float> probs(nbests.size(), 0.0);
for (size_t i = 0; i < nbests.size(); ++i) {
probs[i] = std::exp(alpha * nbests[i].second);
}
std::vector<double> log_probs;
log_probs.reserve(nbests.size());
std::transform(nbests.begin(), nbests.end(), std::back_inserter(log_probs),
[alpha](const auto &nbest) { return alpha * nbest.second; });
const double Z = log_domain::LogSum(log_probs);
std::vector<double> probs;
probs.reserve(log_probs.size());
std::transform(
log_probs.begin(), log_probs.end(), std::back_inserter(probs),
[Z](const auto &log_prob) { return std::exp(log_prob - Z); });
auto *mt = random::GetRandomGenerator();
std::discrete_distribution<int> dist(probs.begin(), probs.end());
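
The rewritten sampling path above works in the log domain: it scales every n-best score by alpha, normalizes with log_domain::LogSum (added to util.cc further down in this commit), and only then exponentiates, so a large alpha * score can no longer overflow or vanish before std::discrete_distribution sees it. A minimal standalone sketch of the same normalization, with illustrative names that are not part of the SentencePiece API:

#include <algorithm>
#include <cmath>
#include <vector>

// Stable softmax over alpha-scaled scores: subtract the log-normalizer Z
// before exponentiating, as the patched SampleEncode now does.
std::vector<double> StableSoftmax(const std::vector<double> &scores,
                                  double alpha) {
  if (scores.empty()) return {};

  std::vector<double> log_probs;
  log_probs.reserve(scores.size());
  for (double s : scores) log_probs.push_back(alpha * s);

  // log-sum-exp computed around the maximum, so every exp() argument is <= 0.
  const double max_lp = *std::max_element(log_probs.begin(), log_probs.end());
  double sum = 0.0;
  for (double lp : log_probs) sum += std::exp(lp - max_lp);
  const double Z = max_lp + std::log(sum);

  std::vector<double> probs;
  probs.reserve(log_probs.size());
  for (double lp : log_probs) probs.push_back(std::exp(lp - Z));
  return probs;  // sums to ~1; safe to feed into std::discrete_distribution
}

The diff reaches the same probabilities through the pairwise LogSum shown later instead of the max trick; both avoid evaluating exp(alpha * score) directly.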
@ -998,6 +1011,8 @@ util::Status SentencePieceProcessor::ApplyExtraOptions(
piece->set_id(PieceToId(absl::string_view(model_->eos_piece().data())));
piece->set_piece(model_->eos_piece().data(),
model_->eos_piece().size());
piece->set_begin(spt->text().size());
piece->set_end(spt->text().size());
} break;
case BOS: {
auto *array = spt->mutable_pieces();
@ -1009,6 +1024,8 @@ util::Status SentencePieceProcessor::ApplyExtraOptions(
piece->set_id(PieceToId(absl::string_view(model_->bos_piece().data())));
piece->set_piece(model_->bos_piece().data(),
model_->bos_piece().size());
piece->set_begin(0);
piece->set_end(0);
} break;
case UNK_PIECE: {
for (int i = 0; i < spt->pieces_size(); ++i) {
@ -1097,9 +1114,13 @@ util::Status LoadModelProto(absl::string_view filename,
auto input = filesystem::NewReadableFile(filename, true);
RETURN_IF_ERROR(input->status());
std::string serialized;
CHECK_OR_RETURN(input->ReadAll(&serialized));
CHECK_OR_RETURN(
model_proto->ParseFromArray(serialized.data(), serialized.size()));
if (!input->ReadAll(&serialized)) {
return util::InternalError(absl::StrCat("could not read ", filename));
}
if (!model_proto->ParseFromArray(serialized.data(), serialized.size())) {
return util::InternalError(
absl::StrCat("could not parse ModelProto from ", filename));
}
return util::OkStatus();
}

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "sentencepiece_trainer.h"
#include <string>
#include <vector>
@ -20,7 +22,6 @@
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_trainer.h"
#include "spec_parser.h"
#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/numbers.h"
@ -197,6 +198,40 @@ util::Status SentencePieceTrainer::Train(
sentence_iterator, serialized_model_proto);
}
namespace {
class VectorSentenceIterator : public SentenceIterator {
public:
explicit VectorSentenceIterator(const std::vector<std::string> &values)
: iter_(values.begin()), end_(values.end()) {}
virtual ~VectorSentenceIterator() {}
virtual bool done() const { return iter_ == end_; }
void Next() override { ++iter_; }
const std::string &value() const override { return *iter_; }
util::Status status() const override { return util::OkStatus(); }
private:
std::vector<std::string>::const_iterator iter_;
std::vector<std::string>::const_iterator end_;
};
} // namespace
// static
util::Status SentencePieceTrainer::Train(
absl::string_view args, const std::vector<std::string> &sentences,
std::string *serialized_model_proto) {
VectorSentenceIterator iter(sentences);
return Train(args, &iter, serialized_model_proto);
}
// static
util::Status SentencePieceTrainer::Train(
const std::unordered_map<std::string, std::string> &kwargs,
const std::vector<std::string> &sentences,
std::string *serialized_model_proto) {
VectorSentenceIterator iter(sentences);
return Train(kwargs, &iter, serialized_model_proto);
}
// static
util::Status SentencePieceTrainer::PopulateNormalizerSpec(
NormalizerSpec *normalizer_spec, bool is_denormalizer) {
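
The two Train overloads added to this file simply wrap an in-memory corpus in the VectorSentenceIterator above and forward to the existing iterator-based Train. A hedged usage sketch modeled on the new TrainFromIterator test further down; the corpus contents and the model_prefix value are placeholders:

#include <string>
#include <vector>
#include "sentencepiece_trainer.h"

int main() {
  // Placeholder corpus: in practice, fill this with real training sentences
  // large enough to support the requested --vocab_size.
  std::vector<std::string> sentences = {/* ... */};
  const auto status = sentencepiece::SentencePieceTrainer::Train(
      "--model_prefix=tmp_model --vocab_size=1000", sentences);
  return status.ok() ? 0 : 1;
}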

View File

@ -89,6 +89,17 @@ class SentencePieceTrainer {
SentenceIterator *sentence_iterator = nullptr,
std::string *serialized_model_proto = nullptr);
// The same as above, but passes the list of sentences.
static util::Status Train(absl::string_view args,
const std::vector<std::string> &sentences,
std::string *serialized_model_proto = nullptr);
// The same as above, but passes the list of sentences.
static util::Status Train(
const std::unordered_map<std::string, std::string> &kwargs,
const std::vector<std::string> &sentences,
std::string *serialized_model_proto = nullptr);
// Handy function to make a normalizer spec from the pre-compiled
// normalization name. Do not use this method in production as it crashes
// when `name` is invalid. Useful for unittesting.

View File

@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "sentencepiece_trainer.h"
#include "filesystem.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "third_party/absl/strings/str_cat.h"
#include "util.h"
@ -129,6 +130,19 @@ TEST(SentencePieceTrainerTest, TrainFromIterator) {
while (fs->ReadLine(&line)) sentences.emplace_back(line);
}
ASSERT_TRUE(SentencePieceTrainer::Train(
absl::StrCat("--model_prefix=", model, " --vocab_size=1000"),
sentences)
.ok());
CheckVocab(model + ".model", 1000);
CheckNormalizer(model + ".model", true, false);
ASSERT_TRUE(SentencePieceTrainer::Train(
{{"model_prefix", model}, {"vocab_size", "1000"}}, sentences)
.ok());
CheckVocab(model + ".model", 1000);
CheckNormalizer(model + ".model", true, false);
VectorIterator it(std::move(sentences));
ASSERT_TRUE(
SentencePieceTrainer::Train(

View File

@ -575,7 +575,6 @@ END:
w.first = string_util::UnicodeTextToUTF8(uw2);
}
// +3 for meta pieces.
if (trainer_spec_.model_type() != TrainerSpec::WORD &&
trainer_spec_.model_type() != TrainerSpec::CHAR) {
CHECK_LE_OR_RETURN(

View File

@ -16,6 +16,7 @@
#include <atomic>
#include <iostream>
#include <memory>
namespace sentencepiece {
@ -187,8 +188,18 @@ std::mt19937 *GetRandomGenerator() {
}
#else
std::mt19937 *GetRandomGenerator() {
thread_local static std::mt19937 mt(GetRandomGeneratorSeed());
return &mt;
// Thread-locals occupy stack space in every thread ever created by the
// program, even if that thread never uses the thread-local variable.
//
// https://maskray.me/blog/2021-02-14-all-about-thread-local-storage
//
// sizeof(std::mt19937) is several kilobytes, so it is safer to put that on
// the heap, leaving only a pointer to it in thread-local storage. This must
// be a unique_ptr, not a raw pointer, so that the generator is not leaked on
// thread exit.
thread_local static auto mt =
std::make_unique<std::mt19937>(GetRandomGeneratorSeed());
return mt.get();
}
#endif
} // namespace random
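
With the engine heap-allocated behind a thread_local std::unique_ptr, a thread pays the several-kilobyte std::mt19937 cost only if it actually calls GetRandomGenerator(), and the unique_ptr destroys the engine on thread exit. Call sites do not change; a minimal per-thread usage sketch, assuming GetRandomGenerator() is declared in util.h as in the existing tree:

#include <random>
#include <thread>
#include <vector>
#include "util.h"

// Each thread lazily builds its own generator on first use; no sharing, no locks.
void DrawSomeNumbers() {
  std::mt19937 *mt = sentencepiece::random::GetRandomGenerator();
  std::uniform_int_distribution<int> dist(0, 9);
  for (int i = 0; i < 3; ++i) (void)dist(*mt);
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) threads.emplace_back(DrawSomeNumbers);
  for (auto &t : threads) t.join();
  return 0;
}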
@ -244,28 +255,40 @@ std::vector<std::string> StrSplitAsCSV(absl::string_view text) {
return result;
}
} // namespace util
#ifdef OS_WIN
namespace win32 {
std::wstring Utf8ToWide(absl::string_view input) {
int output_length = ::MultiByteToWideChar(
const int output_length = ::MultiByteToWideChar(
CP_UTF8, 0, input.data(), static_cast<int>(input.size()), nullptr, 0);
output_length = output_length <= 0 ? 0 : output_length - 1;
if (output_length == 0) {
return L"";
}
std::unique_ptr<wchar_t[]> input_wide(new wchar_t[output_length + 1]);
std::fill(input_wide.get(), input_wide.get() + output_length + 1, L'\0');
std::wstring output(output_length, 0);
const int result = ::MultiByteToWideChar(CP_UTF8, 0, input.data(),
static_cast<int>(input.size()),
input_wide.get(), output_length + 1);
std::wstring output;
if (result > 0) {
output.assign(input_wide.get());
}
return output;
output.data(), output.size());
return result == output_length ? output : L"";
}
} // namespace win32
#endif
} // namespace util
namespace log_domain {
double LogSum(const std::vector<double> &xs) {
if (xs.empty()) {
return -1.0 * std::numeric_limits<double>::max();
}
double sum = xs.front();
auto log_add = [](double xa, double xb) {
if (xa > xb) {
std::swap(xa, xb);
}
return xb + std::log1p(std::exp(xa - xb));
};
for (int i = 1; i < xs.size(); ++i) {
sum = log_add(sum, xs[i]);
}
return sum;
}
} // namespace log_domain
} // namespace sentencepiece
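
LogSum folds the stable pairwise identity log(exp(a) + exp(b)) = b + log1p(exp(a - b)), with a <= b, over the whole vector, so no intermediate exp() can overflow, and it returns the lowest representable double for an empty input. A small self-contained check of that identity against the naive computation (not part of the library; moderate inputs only, where the naive form is still finite):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Pairwise stable log-add, mirroring the lambda inside LogSum above.
double LogAdd(double a, double b) {
  if (a > b) std::swap(a, b);              // ensure a <= b so exp(a - b) <= 1
  return b + std::log1p(std::exp(a - b));  // == log(exp(a) + exp(b))
}

int main() {
  const std::vector<double> xs = {-2.0, 0.5, 1.0};
  double stable = xs.front();
  double naive_sum = std::exp(xs.front());
  for (std::size_t i = 1; i < xs.size(); ++i) {
    stable = LogAdd(stable, xs[i]);
    naive_sum += std::exp(xs[i]);
  }
  // Both paths agree here; the folded form also stays finite when the
  // naive exp() would overflow or underflow.
  assert(std::fabs(stable - std::log(naive_sum)) < 1e-9);
  return 0;
}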

View File

@ -340,6 +340,10 @@ std::string StrError(int errnum);
std::vector<std::string> StrSplitAsCSV(absl::string_view text);
#ifdef OS_WIN
std::wstring Utf8ToWide(const absl::string_view input);
#endif
inline Status OkStatus() { return Status(); }
#define DECLARE_ERROR(FUNC) \
@ -428,5 +432,11 @@ class ThreadPool {
private:
std::vector<std::thread> tasks_;
};
namespace log_domain {
double LogSum(const std::vector<double> &xs);
} // namespace log_domain
} // namespace sentencepiece
#endif // UTIL_H_