make the error message more descriptive. null termnate string in Utf8ToWide

This commit is contained in:
Taku Kudo 2023-04-03 02:24:52 +00:00
parent 359c04397c
commit 573cc39aab
2 changed files with 15 additions and 13 deletions

View File

@ -760,19 +760,19 @@ util::Status TrainerInterface::InitMetaPieces() {
std::set<std::string> dup;
int id = 0;
auto insert_meta_symbol = [&id, &dup, this](
const std::string &w,
ModelProto::SentencePiece::Type type) -> bool {
auto insert_meta_symbol =
[&id, &dup, this](const std::string &w,
ModelProto::SentencePiece::Type type) -> util::Status {
if (!dup.insert(w).second) {
LOG(ERROR) << w << " is already defined.";
return false;
return util::InternalError(absl::StrCat(
w, " is already defined. duplicated symbols are not allowed."));
}
if (w == trainer_spec_.unk_piece()) {
LOG(ERROR) << trainer_spec_.unk_piece()
<< " must not be defined with --control_symbols and "
"--user_defined_symbols.";
return false;
return util::InternalError(
absl::StrCat(trainer_spec_.unk_piece(),
" must not be defined with --control_symbols and "
"--user_defined_symbols."));
}
if (w == trainer_spec_.bos_piece() && trainer_spec_.bos_id() >= 0) {
@ -785,21 +785,22 @@ util::Status TrainerInterface::InitMetaPieces() {
while (meta_pieces_.find(id) != meta_pieces_.end()) ++id;
meta_pieces_[id] = std::make_pair(w, type);
}
return true;
return util::OkStatus();
};
for (const auto &w : trainer_spec_.control_symbols()) {
CHECK_OR_RETURN(insert_meta_symbol(w, ModelProto::SentencePiece::CONTROL));
RETURN_IF_ERROR(insert_meta_symbol(w, ModelProto::SentencePiece::CONTROL));
}
for (const auto &w : trainer_spec_.user_defined_symbols()) {
CHECK_OR_RETURN(
RETURN_IF_ERROR(
insert_meta_symbol(w, ModelProto::SentencePiece::USER_DEFINED));
}
if (trainer_spec_.byte_fallback()) {
for (int i = 0; i < 256; ++i) {
CHECK_OR_RETURN(
RETURN_IF_ERROR(
insert_meta_symbol(ByteToPiece(i), ModelProto::SentencePiece::BYTE));
}
}

View File

@ -256,6 +256,7 @@ std::wstring Utf8ToWide(absl::string_view input) {
return L"";
}
std::unique_ptr<wchar_t[]> input_wide(new wchar_t[output_length + 1]);
std::fill(input_wide.get(), input_wide.get() + output_length + 1, L'\0');
const int result = ::MultiByteToWideChar(CP_UTF8, 0, input.data(),
static_cast<int>(input.size()),
input_wide.get(), output_length + 1);