Support pretokenization in BPE mode.

This commit is contained in:
Taku Kudo 2023-04-11 06:48:08 +00:00
parent 119e58d97a
commit e07ebf74d7
3 changed files with 49 additions and 6 deletions

View File

@ -12,13 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bpe_model_trainer.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>
#include "bpe_model_trainer.h"
#include "pretokenizer_for_training.h"
#include "third_party/absl/container/flat_hash_set.h"
#include "third_party/absl/strings/str_join.h"
#include "third_party/absl/strings/str_replace.h"
#include "util.h"
namespace sentencepiece {
@ -189,6 +193,24 @@ util::Status Trainer::Train() {
SplitSentencesByWhitespace();
}
// The pretokenizer is applied only at training time.
// It is used as a constraint on piece extraction.
const auto *pretokenizer = SentencePieceTrainer::GetPretokenizerForTraining();
if (pretokenizer || !trainer_spec_.pretokenization_delimiter().empty()) {
absl::string_view delimiter = trainer_spec_.pretokenization_delimiter();
LOG(INFO) << "Preprocessing with pretokenizer...";
for (auto &w : sentences_) {
if (pretokenizer) {
w.first = absl::StrJoin(pretokenizer->PreTokenize(w.first),
TrainerInterface::kUPPBoundaryStr);
} else if (!delimiter.empty()) {
w.first = absl::StrReplaceAll(
w.first, {{delimiter, TrainerInterface::kUPPBoundaryStr}});
}
}
}
// Initializes symbols_. symbols_[sid][i] stores a unary symbol.
symbols_.resize(sentences_.size());
for (size_t i = 0; i < sentences_.size(); ++i) {

View File

@ -83,8 +83,9 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
if (SentencePieceTrainer::GetPretokenizerForTraining() ||
!trainer_spec.pretokenization_delimiter().empty()) {
CHECK_EQ_OR_RETURN(TrainerSpec::UNIGRAM, trainer_spec.model_type())
<< "PretokenizerForTraining is only supported in UNIGRAM mode.";
CHECK_OR_RETURN(trainer_spec.model_type() == TrainerSpec::UNIGRAM ||
trainer_spec.model_type() == TrainerSpec::BPE)
<< "PretokenizerForTraining is only supported in UNIGRAM or BPE mode.";
}
return util::OkStatus();
@ -231,7 +232,6 @@ bool TrainerInterface::IsValidSentencePiece(
if (c == 0x0000) { // NULL is not allowed for Darts (TRIE).
return false;
}
// kUPPBoundaryChar is included when split_by_upp_for_training is true.
if (c == kUPPBoundaryChar) {
return false;
}

View File

@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "trainer_interface.h"
#include <utility>
#include "filesystem.h"
#include "testharness.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_format.h"
#include "trainer_interface.h"
#include "util.h"
namespace sentencepiece {
@ -72,7 +73,7 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
EXPECT_FALSE(IsValid("F1"));
EXPECT_FALSE(IsValid("1F"));
EXPECT_FALSE(IsValid("1A2"));
EXPECT_TRUE(IsValid("$10")); // $ and 1 are both "common" script.
EXPECT_TRUE(IsValid("$10")); // $ and 1 are both "common" script.
EXPECT_FALSE(IsValid("$ABC"));
EXPECT_FALSE(IsValid("ab\tbc")); // "\t" is UPP boundary.
EXPECT_FALSE(IsValid("ab cd"));
@ -113,6 +114,26 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
EXPECT_TRUE(IsValid("$10"));
EXPECT_TRUE(IsValid("$ABC"));
trainer_spec.set_split_by_unicode_script(true);
trainer_spec.set_split_by_number(true);
EXPECT_FALSE(IsValid("F1"));
EXPECT_TRUE(IsValid("$10"));
trainer_spec.set_split_by_unicode_script(true);
trainer_spec.set_split_by_number(false);
EXPECT_TRUE(IsValid("F1"));
EXPECT_TRUE(IsValid("$10"));
trainer_spec.set_split_by_unicode_script(false);
trainer_spec.set_split_by_number(true);
EXPECT_TRUE(IsValid("F1"));
EXPECT_TRUE(IsValid("$10"));
trainer_spec.set_split_by_unicode_script(false);
trainer_spec.set_split_by_number(false);
EXPECT_TRUE(IsValid("F1"));
EXPECT_TRUE(IsValid("$10"));
trainer_spec.set_max_sentencepiece_length(4);
EXPECT_TRUE(IsValid("1234"));
EXPECT_FALSE(IsValid("12345"));