From 573cc39aabd70b25c6c0a80ce4d060a53f877597 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Mon, 3 Apr 2023 02:24:52 +0000 Subject: [PATCH] make the error message more descriptive. null-terminate string in Utf8ToWide --- src/trainer_interface.cc | 27 ++++++++++++++------------- src/util.cc | 1 + 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index 7270f29..fb4087a 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -760,19 +760,19 @@ util::Status TrainerInterface::InitMetaPieces() { std::set dup; int id = 0; - auto insert_meta_symbol = [&id, &dup, this]( - const std::string &w, - ModelProto::SentencePiece::Type type) -> bool { + auto insert_meta_symbol = + [&id, &dup, this](const std::string &w, + ModelProto::SentencePiece::Type type) -> util::Status { if (!dup.insert(w).second) { - LOG(ERROR) << w << " is already defined."; - return false; + return util::InternalError(absl::StrCat( + w, " is already defined. 
duplicated symbols are not allowed.")); } if (w == trainer_spec_.unk_piece()) { - LOG(ERROR) << trainer_spec_.unk_piece() - << " must not be defined with --control_symbols and " - "--user_defined_symbols."; - return false; + return util::InternalError( + absl::StrCat(trainer_spec_.unk_piece(), + " must not be defined with --control_symbols and " + "--user_defined_symbols.")); } if (w == trainer_spec_.bos_piece() && trainer_spec_.bos_id() >= 0) { @@ -785,21 +785,22 @@ util::Status TrainerInterface::InitMetaPieces() { while (meta_pieces_.find(id) != meta_pieces_.end()) ++id; meta_pieces_[id] = std::make_pair(w, type); } - return true; + + return util::OkStatus(); }; for (const auto &w : trainer_spec_.control_symbols()) { - CHECK_OR_RETURN(insert_meta_symbol(w, ModelProto::SentencePiece::CONTROL)); + RETURN_IF_ERROR(insert_meta_symbol(w, ModelProto::SentencePiece::CONTROL)); } for (const auto &w : trainer_spec_.user_defined_symbols()) { - CHECK_OR_RETURN( + RETURN_IF_ERROR( insert_meta_symbol(w, ModelProto::SentencePiece::USER_DEFINED)); } if (trainer_spec_.byte_fallback()) { for (int i = 0; i < 256; ++i) { - CHECK_OR_RETURN( + RETURN_IF_ERROR( insert_meta_symbol(ByteToPiece(i), ModelProto::SentencePiece::BYTE)); } } diff --git a/src/util.cc b/src/util.cc index c4c523c..538b00b 100644 --- a/src/util.cc +++ b/src/util.cc @@ -256,6 +256,7 @@ std::wstring Utf8ToWide(absl::string_view input) { return L""; } std::unique_ptr input_wide(new wchar_t[output_length + 1]); + std::fill(input_wide.get(), input_wide.get() + output_length + 1, L'\0'); const int result = ::MultiByteToWideChar(CP_UTF8, 0, input.data(), static_cast(input.size()), input_wide.get(), output_length + 1);