make the error message more descriptive. null termnate string in Utf8ToWide

This commit is contained in:
Taku Kudo 2023-04-03 02:24:52 +00:00
parent 359c04397c
commit 573cc39aab
2 changed files with 15 additions and 13 deletions

View File

@ -760,19 +760,19 @@ util::Status TrainerInterface::InitMetaPieces() {
std::set<std::string> dup; std::set<std::string> dup;
int id = 0; int id = 0;
auto insert_meta_symbol = [&id, &dup, this]( auto insert_meta_symbol =
const std::string &w, [&id, &dup, this](const std::string &w,
ModelProto::SentencePiece::Type type) -> bool { ModelProto::SentencePiece::Type type) -> util::Status {
if (!dup.insert(w).second) { if (!dup.insert(w).second) {
LOG(ERROR) << w << " is already defined."; return util::InternalError(absl::StrCat(
return false; w, " is already defined. duplicated symbols are not allowed."));
} }
if (w == trainer_spec_.unk_piece()) { if (w == trainer_spec_.unk_piece()) {
LOG(ERROR) << trainer_spec_.unk_piece() return util::InternalError(
<< " must not be defined with --control_symbols and " absl::StrCat(trainer_spec_.unk_piece(),
"--user_defined_symbols."; " must not be defined with --control_symbols and "
return false; "--user_defined_symbols."));
} }
if (w == trainer_spec_.bos_piece() && trainer_spec_.bos_id() >= 0) { if (w == trainer_spec_.bos_piece() && trainer_spec_.bos_id() >= 0) {
@ -785,21 +785,22 @@ util::Status TrainerInterface::InitMetaPieces() {
while (meta_pieces_.find(id) != meta_pieces_.end()) ++id; while (meta_pieces_.find(id) != meta_pieces_.end()) ++id;
meta_pieces_[id] = std::make_pair(w, type); meta_pieces_[id] = std::make_pair(w, type);
} }
return true;
return util::OkStatus();
}; };
for (const auto &w : trainer_spec_.control_symbols()) { for (const auto &w : trainer_spec_.control_symbols()) {
CHECK_OR_RETURN(insert_meta_symbol(w, ModelProto::SentencePiece::CONTROL)); RETURN_IF_ERROR(insert_meta_symbol(w, ModelProto::SentencePiece::CONTROL));
} }
for (const auto &w : trainer_spec_.user_defined_symbols()) { for (const auto &w : trainer_spec_.user_defined_symbols()) {
CHECK_OR_RETURN( RETURN_IF_ERROR(
insert_meta_symbol(w, ModelProto::SentencePiece::USER_DEFINED)); insert_meta_symbol(w, ModelProto::SentencePiece::USER_DEFINED));
} }
if (trainer_spec_.byte_fallback()) { if (trainer_spec_.byte_fallback()) {
for (int i = 0; i < 256; ++i) { for (int i = 0; i < 256; ++i) {
CHECK_OR_RETURN( RETURN_IF_ERROR(
insert_meta_symbol(ByteToPiece(i), ModelProto::SentencePiece::BYTE)); insert_meta_symbol(ByteToPiece(i), ModelProto::SentencePiece::BYTE));
} }
} }

View File

@ -256,6 +256,7 @@ std::wstring Utf8ToWide(absl::string_view input) {
return L""; return L"";
} }
std::unique_ptr<wchar_t[]> input_wide(new wchar_t[output_length + 1]); std::unique_ptr<wchar_t[]> input_wide(new wchar_t[output_length + 1]);
std::fill(input_wide.get(), input_wide.get() + output_length + 1, L'\0');
const int result = ::MultiByteToWideChar(CP_UTF8, 0, input.data(), const int result = ::MultiByteToWideChar(CP_UTF8, 0, input.data(),
static_cast<int>(input.size()), static_cast<int>(input.size()),
input_wide.get(), output_length + 1); input_wide.get(), output_length + 1);