Merge pull request #38 from huggingface/wordpiece-training

Wordpiece training
This commit is contained in:
MOI Anthony 2020-01-03 19:56:33 -05:00 committed by GitHub
commit a1891387ed
10 changed files with 559 additions and 54 deletions

View File

@@ -0,0 +1,57 @@
import argparse
import glob
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, normalizers
parser = argparse.ArgumentParser()
parser.add_argument(
    "--files",
    default=None,
    metavar="path",
    type=str,
    required=True,
    help="The files to use for training; accepts '**/*.txt'-style patterns if enclosed in quotes",
)
parser.add_argument(
    "--out",
    default="./",
    type=str,
    help="Path to the output directory, where the files will be saved",
)
parser.add_argument(
    "--name",
    default="bert-wordpiece",
    type=str,
    help="The name of the output vocab files",
)
args = parser.parse_args()
files = glob.glob(args.files)
if not files:
    print(f"No files found matching pattern: {args.files}")
    exit(1)
# Initialize an empty tokenizer
tokenizer = Tokenizer(models.WordPiece.empty())
# Customize all the steps
tokenizer.with_normalizer(normalizers.BertNormalizer.new(
    clean_text=True,
    handle_chinese_chars=True,
    strip_accents=True,
    lowercase=True,
))
tokenizer.with_pre_tokenizer(pre_tokenizers.BertPreTokenizer.new())
tokenizer.with_decoder(decoders.WordPiece.new())
# And then train
trainer = trainers.WordPieceTrainer.new(
    vocab_size=50000,
    min_frequency=2,
    show_progress=True,
    special_tokens=["<s>", "<unk>", "<pad>", "</s>"],
    limit_alphabet=1000,
    continuing_subword_prefix="##",
)
tokenizer.train(trainer, files)
# Save the files
tokenizer.model.save(args.out, args.name)

View File

@@ -40,7 +40,7 @@ trainer = trainers.BpeTrainer.new(
vocab_size=50000,
min_frequency=2,
show_progress=True,
special_tokens=[ "<s>", "<pad>", "</s" ],
special_tokens=[ "<s>", "<pad>", "</s>" ],
initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
)
tokenizer.train(trainer, files)

View File

@@ -18,6 +18,7 @@ use pyo3::wrap_pymodule;
fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<trainers::Trainer>()?;
m.add_class::<trainers::BpeTrainer>()?;
m.add_class::<trainers::WordPieceTrainer>()?;
Ok(())
}

View File

@@ -133,4 +133,11 @@ impl WordPiece {
}),
}
}
#[staticmethod]
fn empty() -> Model {
Model {
model: Container::Owned(Box::new(tk::models::wordpiece::WordPiece::default())),
}
}
}

View File

@@ -1,5 +1,6 @@
extern crate tokenizers as tk;
use super::error::ToPyResult;
use super::utils::Container;
use pyo3::prelude::*;
use pyo3::types::*;
@@ -13,39 +14,96 @@ pub struct Trainer {
pub struct BpeTrainer {}
#[pymethods]
impl BpeTrainer {
/// new(/vocab_size, min_frequency)
/// new(/ vocab_size, min_frequency)
/// --
///
/// Create a new BpeTrainer with the given configuration
#[staticmethod]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<Trainer> {
let mut trainer = tk::models::bpe::BpeTrainer::default();
let mut builder = tk::models::bpe::BpeTrainer::builder();
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
let key: &str = key.extract()?;
match key {
"vocab_size" => trainer.vocab_size = val.extract()?,
"min_frequency" => trainer.min_frequency = val.extract()?,
"show_progress" => trainer.show_progress = val.extract()?,
"special_tokens" => trainer.special_tokens = val.extract()?,
"limit_alphabet" => trainer.limit_alphabet = val.extract()?,
"vocab_size" => builder = builder.vocab_size(val.extract()?),
"min_frequency" => builder = builder.min_frequency(val.extract()?),
"show_progress" => builder = builder.show_progress(val.extract()?),
"special_tokens" => builder = builder.special_tokens(val.extract()?),
"limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
"initial_alphabet" => {
let alphabet: Vec<String> = val.extract()?;
trainer.initial_alphabet = alphabet
.into_iter()
.map(|s| s.chars().nth(0))
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect();
builder = builder.initial_alphabet(
alphabet
.into_iter()
.map(|s| s.chars().nth(0))
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect(),
);
}
"continuing_subword_prefix" => {
builder = builder.continuing_subword_prefix(val.extract()?)
}
"end_of_word_suffix" => builder = builder.end_of_word_suffix(val.extract()?),
_ => println!("Ignored unknown kwargs option {}", key),
};
}
}
let trainer: PyResult<_> = ToPyResult(builder.build()).into();
Ok(Trainer {
trainer: Container::Owned(Box::new(trainer)),
trainer: Container::Owned(Box::new(trainer?)),
})
}
}
#[pyclass]
pub struct WordPieceTrainer {}
#[pymethods]
impl WordPieceTrainer {
/// new(/ vocab_size, min_frequency)
/// --
///
/// Create a new WordPieceTrainer with the given configuration
#[staticmethod]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<Trainer> {
let mut builder = tk::models::wordpiece::WordPieceTrainer::builder();
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
let key: &str = key.extract()?;
match key {
"vocab_size" => builder = builder.vocab_size(val.extract()?),
"min_frequency" => builder = builder.min_frequency(val.extract()?),
"show_progress" => builder = builder.show_progress(val.extract()?),
"special_tokens" => builder = builder.special_tokens(val.extract()?),
"limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
"initial_alphabet" => {
let alphabet: Vec<String> = val.extract()?;
builder = builder.initial_alphabet(
alphabet
.into_iter()
.map(|s| s.chars().nth(0))
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect(),
);
}
"continuing_subword_prefix" => {
builder = builder.continuing_subword_prefix(val.extract()?)
}
"end_of_word_suffix" => builder = builder.end_of_word_suffix(val.extract()?),
_ => println!("Ignored unknown kwargs option {}", key),
};
}
}
let trainer: PyResult<_> = ToPyResult(builder.build()).into();
Ok(Trainer {
trainer: Container::Owned(Box::new(trainer?)),
})
}
}

View File

@@ -1,4 +1,4 @@
use std::{convert::From, io};
use std::{convert::From, io, iter, mem};
mod cache;
mod model;
@@ -62,6 +62,47 @@ impl std::error::Error for Error {
}
}
/// Provides access to the `FirstLastIterator` for any `Iterator`
pub(crate) trait WithFirstLastIterator: Iterator + Sized {
fn with_first_and_last(self) -> FirstLastIterator<Self>;
}
impl<I> WithFirstLastIterator for I
where
I: Iterator,
{
fn with_first_and_last(self) -> FirstLastIterator<Self> {
FirstLastIterator {
first: true,
iter: self.peekable(),
}
}
}
/// Provides information about whether an item is the first and/or the last of the iterator
pub(crate) struct FirstLastIterator<I>
where
I: Iterator,
{
first: bool,
iter: iter::Peekable<I>,
}
impl<I> Iterator for FirstLastIterator<I>
where
I: Iterator,
{
/// (is_first, is_last, item)
type Item = (bool, bool, I::Item);
fn next(&mut self) -> Option<Self::Item> {
let first = mem::replace(&mut self.first, false);
self.iter
.next()
.map(|e| (first, self.iter.peek().is_none(), e))
}
}
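For intuition, here is a minimal standalone sketch of the same tagging pattern, written with a plain `Peekable` rather than the crate-private trait, showing the `(is_first, is_last, item)` triples the adapter produces:

fn main() {
    let mut first = true;
    let mut iter = "abc".chars().peekable();
    while let Some(c) = iter.next() {
        // Same trick as `FirstLastIterator::next`: take the flag, then peek ahead
        let is_first = std::mem::replace(&mut first, false);
        let is_last = iter.peek().is_none();
        // Prints: (true, false, 'a'), (false, false, 'b'), (false, true, 'c')
        println!("({}, {}, {:?})", is_first, is_last, c);
    }
}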
// Re-export
pub use cache::*;
pub use model::*;

View File

@@ -1,4 +1,4 @@
use super::{Cache, Error, Pair, Word};
use super::{Cache, Error, Pair, WithFirstLastIterator, Word};
use crate::tokenizer::{Model, Offsets, Result, Token};
use rand::{thread_rng, Rng};
use serde_json::Value;
@@ -18,6 +18,8 @@ struct Config {
cache_capacity: Option<usize>,
dropout: Option<f32>,
unk_token: Option<u32>,
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
}
/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
@@ -61,6 +63,18 @@ impl BpeBuilder {
self
}
/// Set the `continuing_subword_prefix` option.
pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
self.config.continuing_subword_prefix = Some(prefix);
self
}
/// Set the `end_of_word_suffix` option.
pub fn end_of_word_suffix(mut self, prefix: String) -> Self {
self.config.end_of_word_suffix = Some(prefix);
self
}
/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
pub fn build(self) -> Result<BPE> {
// Validate dropout.
@@ -92,6 +106,8 @@ impl BpeBuilder {
cache,
dropout: self.config.dropout,
unk_token: self.config.unk_token,
continuing_subword_prefix: self.config.continuing_subword_prefix,
end_of_word_suffix: self.config.end_of_word_suffix,
})
}
}
@@ -111,6 +127,10 @@ pub struct BPE {
dropout: Option<f32>,
/// The unknown token to be used when we encounter an unknown char
unk_token: Option<u32>,
/// An optional prefix to use for any subword that only appears behind another one
continuing_subword_prefix: Option<String>,
/// An optional suffix to characterize an end-of-word subword
end_of_word_suffix: Option<String>,
}
impl Default for BPE {
@@ -134,6 +154,8 @@ impl Clone for BPE {
cache: fresh_cache,
dropout: self.dropout,
unk_token: self.unk_token,
continuing_subword_prefix: self.continuing_subword_prefix.clone(),
end_of_word_suffix: self.end_of_word_suffix.clone(),
}
}
}
@@ -215,10 +237,37 @@ impl BPE {
}
}
pub fn get_vocab(&self) -> &HashMap<String, u32> {
&self.vocab
}
pub fn get_unk_token(&self) -> &Option<u32> {
&self.unk_token
}
pub fn get_continuing_subword_prefix(&self) -> &Option<String> {
&self.continuing_subword_prefix
}
fn merge_word(&self, w: &str) -> Word {
let mut word = Word::new();
for c in w.chars() {
if let Some(id) = self.vocab.get(&c.to_string()) {
for (is_first, is_last, c) in w.chars().with_first_and_last() {
let mut s = c.to_string();
// Add the `continuing_subword_prefix` if relevant
if !is_first {
if let Some(prefix) = &self.continuing_subword_prefix {
s = format!("{}{}", prefix, s);
}
}
// Add the `end_of_word_suffix` if relevant
if is_last {
if let Some(suffix) = &self.end_of_word_suffix {
s = format!("{}{}", s, suffix);
}
}
if let Some(id) = self.vocab.get(&s) {
word.add(*id);
} else if let Some(unk) = &self.unk_token {
// Handle UNK token
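To make the decoration above concrete, here is a hedged standalone sketch (a hypothetical `decorate` helper, not the library API) that applies a continuing subword prefix and an end-of-word suffix the same way `merge_word` does before the vocab lookup:

// Hypothetical helper mirroring the decoration logic in `merge_word`.
fn decorate(word: &str, prefix: Option<&str>, suffix: Option<&str>) -> Vec<String> {
    let mut out = Vec::new();
    let mut first = true;
    let mut iter = word.chars().peekable();
    while let Some(c) = iter.next() {
        let mut s = c.to_string();
        // Add the continuing subword prefix to every char but the first
        if !first {
            if let Some(p) = prefix {
                s = format!("{}{}", p, s);
            }
        }
        // Add the end-of-word suffix to the last char
        if iter.peek().is_none() {
            if let Some(suf) = suffix {
                s = format!("{}{}", s, suf);
            }
        }
        first = false;
        out.push(s);
    }
    out
}

fn main() {
    // Prints ["h", "##e", "##l", "##l", "##o</w>"]
    println!("{:?}", decorate("hello", Some("##"), Some("</w>")));
}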

View File

@@ -1,10 +1,108 @@
#![allow(clippy::map_entry)]
use super::{Pair, Word, BPE};
use super::{Pair, WithFirstLastIterator, Word, BPE};
use crate::tokenizer::{Model, Result, Trainer};
use indicatif::{ProgressBar, ProgressStyle};
use std::collections::{HashMap, HashSet};
#[derive(Default)]
pub struct Config {
min_frequency: Option<u32>,
vocab_size: Option<usize>,
show_progress: Option<bool>,
special_tokens: Option<Vec<String>>,
limit_alphabet: Option<usize>,
initial_alphabet: Option<HashSet<char>>,
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
}
#[derive(Default)]
pub struct BpeTrainerBuilder {
config: Config,
}
impl BpeTrainerBuilder {
/// Constructs a new `BpeTrainerBuilder`
pub fn new() -> Self {
Self::default()
}
/// Set the expected minimum frequency
pub fn min_frequency(mut self, frequency: u32) -> Self {
self.config.min_frequency = Some(frequency);
self
}
/// Set the vocabulary size
pub fn vocab_size(mut self, size: usize) -> Self {
self.config.vocab_size = Some(size);
self
}
/// Set whether to show progress
pub fn show_progress(mut self, show: bool) -> Self {
self.config.show_progress = Some(show);
self
}
/// Set the special tokens
pub fn special_tokens(mut self, tokens: Vec<String>) -> Self {
self.config.special_tokens = Some(tokens);
self
}
/// Set whether to limit the alphabet
pub fn limit_alphabet(mut self, limit: usize) -> Self {
self.config.limit_alphabet = Some(limit);
self
}
/// Set the initial alphabet
pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
self.config.initial_alphabet = Some(alphabet);
self
}
/// Set the continuing_subword_prefix
pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
self.config.continuing_subword_prefix = Some(prefix);
self
}
/// Set the end_of_word_suffix
pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
self.config.end_of_word_suffix = Some(suffix);
self
}
/// Constructs the final BpeTrainer
pub fn build(self) -> Result<BpeTrainer> {
let mut trainer = BpeTrainer::default();
if let Some(freq) = self.config.min_frequency {
trainer.min_frequency = freq;
}
if let Some(vocab_size) = self.config.vocab_size {
trainer.vocab_size = vocab_size;
}
if let Some(show) = self.config.show_progress {
trainer.show_progress = show;
}
if let Some(special_tokens) = self.config.special_tokens {
trainer.special_tokens = special_tokens;
}
if let Some(alphabet) = self.config.initial_alphabet {
trainer.initial_alphabet = alphabet;
}
trainer.limit_alphabet = self.config.limit_alphabet;
trainer.continuing_subword_prefix = self.config.continuing_subword_prefix;
trainer.end_of_word_suffix = self.config.end_of_word_suffix;
Ok(trainer)
}
}
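For reference, a minimal sketch of how the new builder might be used; the import path and the values are illustrative, not taken from this diff:

use tokenizers::models::bpe::BpeTrainer;

fn main() {
    // Configure a trainer through the new builder (values are illustrative)
    let trainer = BpeTrainer::builder()
        .vocab_size(30_000)
        .min_frequency(2)
        .show_progress(false)
        .continuing_subword_prefix("##".into())
        .end_of_word_suffix("</w>".into())
        .build()
        .unwrap();
    // `trainer.train(word_counts)` then yields a `BPE` using these options
    let _ = trainer;
}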
/// In charge of training a BPE model from a mapping of words to word counts.
///
/// # Examples
@@ -22,12 +120,23 @@ use std::collections::{HashMap, HashSet};
/// let model = trainer.train(word_counts);
/// ```
pub struct BpeTrainer {
pub min_frequency: u32,
pub vocab_size: usize,
pub show_progress: bool,
pub special_tokens: Vec<String>,
pub limit_alphabet: Option<usize>,
pub initial_alphabet: HashSet<char>,
/// The minimum frequency a pair must have to produce a merge operation
min_frequency: u32,
/// The target vocabulary size
vocab_size: usize,
/// Whether to show progress while training
show_progress: bool,
/// A list of special tokens that the model should know of
special_tokens: Vec<String>,
/// An optional limit on the number of initial tokens kept before computing merges
limit_alphabet: Option<usize>,
/// The initial alphabet we absolutely want to include. This lets us cover
/// characters that are not necessarily in the training set
initial_alphabet: HashSet<char>,
/// An optional prefix to use for any subword that only appears behind another one
continuing_subword_prefix: Option<String>,
/// An optional suffix to characterize an end-of-word subword
end_of_word_suffix: Option<String>,
}
impl Default for BpeTrainer {
@@ -39,6 +148,8 @@ impl Default for BpeTrainer {
special_tokens: vec![],
limit_alphabet: None,
initial_alphabet: HashSet::new(),
continuing_subword_prefix: None,
end_of_word_suffix: None,
}
}
}
@@ -52,6 +163,10 @@ impl BpeTrainer {
}
}
pub fn builder() -> BpeTrainerBuilder {
BpeTrainerBuilder::new()
}
/// Setup a progress bar if asked to show progress
fn setup_progress(&self) -> Option<ProgressBar> {
if self.show_progress {
@@ -152,13 +267,60 @@ impl BpeTrainer {
}
});
}
}
impl Trainer for BpeTrainer {
/// Train a BPE model
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
/// Tokenize words and add subwords to the vocabulary when relevant
fn tokenize_words(
&self,
wc: &HashMap<String, u32>,
w2id: &mut HashMap<String, u32>,
id2w: &mut Vec<String>,
p: &Option<ProgressBar>,
) -> (Vec<Word>, Vec<i32>) {
let mut words: Vec<Word> = vec![];
let mut counts: Vec<i32> = vec![];
for (word, count) in wc {
let mut current_word = Word::new();
counts.push(*count as i32);
for (is_first, is_last, c) in word.chars().with_first_and_last() {
let mut s = c.to_string();
if w2id.contains_key(&s) {
// Found the initial char in the authorized alphabet
// Add the `continuing_subword_prefix` if relevant
if !is_first {
if let Some(prefix) = &self.continuing_subword_prefix {
s = format!("{}{}", prefix, s);
}
}
// Add the `end_of_word_suffix` if relevant
if is_last {
if let Some(suffix) = &self.end_of_word_suffix {
s = format!("{}{}", s, suffix);
}
}
// Insert the newly formed string if necessary
if !w2id.contains_key(&s) {
id2w.push(s.clone());
w2id.insert(s.clone(), (id2w.len() - 1) as u32);
}
current_word.add(w2id[&s]);
}
}
words.push(current_word);
if let Some(p) = p {
p.inc(1);
}
}
(words, counts)
}
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<BPE> {
let mut word_to_id: HashMap<String, u32> = HashMap::new();
let mut id_to_word: Vec<String> = vec![];
@@ -178,22 +340,8 @@ impl Trainer for BpeTrainer {
// 3. Tokenize words
//
self.update_progress(&progress, word_counts.len(), "Tokenize words");
for (word, count) in &word_counts {
let mut current_word = Word::new();
counts.push(*count as i32);
for c in word.chars() {
let s = c.to_string();
if let Some(id) = word_to_id.get(&s) {
current_word.add(*id);
}
}
words.push(current_word);
if let Some(p) = &progress {
p.inc(1);
}
}
let (mut words, counts) =
self.tokenize_words(&word_counts, &mut word_to_id, &mut id_to_word, &progress);
self.finalize_progress(&progress, words.len());
//
@@ -258,10 +406,16 @@ impl Trainer for BpeTrainer {
break;
}
let new_token = format!(
"{}{}",
id_to_word[best_pair.0 as usize], id_to_word[best_pair.1 as usize]
);
// Build the new token: if `part_b` starts with the continuing subword
// prefix, strip it first, so that merging e.g. "##l" and "##lo" yields
// "##llo" rather than "##l##lo"
let part_a = &id_to_word[best_pair.0 as usize];
let mut part_b = id_to_word[best_pair.1 as usize].to_owned();
if let Some(prefix) = &self.continuing_subword_prefix {
if part_b.starts_with(prefix) {
let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum();
part_b = part_b[prefix_byte_len..].to_string();
}
}
let new_token = format!("{}{}", part_a, part_b);
// Insert new token
let new_token_id = id_to_word.len() as u32;
@@ -306,14 +460,22 @@ impl Trainer for BpeTrainer {
}
self.finalize_progress(&progress, merges.len());
Ok(Box::new(BPE::new(
Ok(BPE::new(
word_to_id,
merges
.into_iter()
.enumerate()
.map(|(index, (pair, new_id))| (pair, (index as u32, new_id)))
.collect(),
)))
))
}
}
impl Trainer for BpeTrainer {
/// Train a BPE model
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
let bpe = self.train(word_counts)?;
Ok(Box::new(bpe))
}
/// Process a bunch of tokens, counting them
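A quick sketch of how the two `train` entry points relate (import paths assumed, counts illustrative): the inherent method returns the concrete `BPE`, while the `Trainer` impl boxes it behind the `Model` trait:

use std::collections::HashMap;
use tokenizers::models::bpe::BpeTrainer;
use tokenizers::tokenizer::Trainer;

fn main() {
    let mut word_counts: HashMap<String, u32> = HashMap::new();
    word_counts.insert("hug".into(), 10);
    word_counts.insert("hugs".into(), 5);

    let trainer = BpeTrainer::builder().show_progress(false).build().unwrap();
    // Inherent method: returns the concrete `BPE` model
    let _bpe = trainer.train(word_counts.clone()).unwrap();
    // Trait method: returns an opaque `Box<dyn Model + Sync>`
    let _model = Trainer::train(&trainer, word_counts).unwrap();
}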

View File

@@ -1,3 +1,4 @@
use crate::models::bpe::BPE;
use crate::tokenizer::{Model, Offsets, Result, Token};
use std::{
collections::HashMap,
@@ -8,6 +9,9 @@ use std::{
path::{Path, PathBuf},
};
mod trainer;
pub use trainer::*;
#[derive(Debug)]
pub enum Error {
MissingUnkToken,
@@ -27,6 +31,7 @@ impl fmt::Display for Error {
pub struct WordPiece {
unk_token: String,
continuing_subword_prefix: String,
max_input_chars_per_word: usize,
vocab: HashMap<String, u32>,
vocab_r: HashMap<u32, String>,
@@ -37,6 +42,7 @@ impl Default for WordPiece {
WordPiece {
vocab: HashMap::new(),
vocab_r: HashMap::new(),
continuing_subword_prefix: String::from("##"),
unk_token: String::from("[UNK]"),
max_input_chars_per_word: 100,
}
@@ -63,8 +69,35 @@ impl WordPiece {
vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(),
unk_token,
max_input_chars_per_word: max_input_chars_per_word.unwrap_or(100),
..Default::default()
})
}
pub fn from_bpe(bpe: &BPE) -> Self {
let vocab = bpe.get_vocab().clone();
let vocab_r = vocab
.clone()
.into_iter()
.map(|(token, id)| (id, token))
.collect();
let mut wp = WordPiece {
vocab,
vocab_r,
..Default::default()
};
if let Some(unk) = bpe.get_unk_token() {
if let Some(unk_token) = wp.vocab_r.get(&unk) {
wp.unk_token = unk_token.to_owned();
}
}
if let Some(prefix) = bpe.get_continuing_subword_prefix() {
wp.continuing_subword_prefix = prefix.to_owned();
}
wp
}
}
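A hedged sketch of the conversion path this enables; a default (empty) `BPE` stands in for a trained model so the snippet stays runnable, and the import paths are assumed:

use tokenizers::models::bpe::BPE;
use tokenizers::models::wordpiece::WordPiece;

fn main() {
    // With a trained model, the vocab, unk token, and continuing subword
    // prefix carry over; a default BPE just produces an empty WordPiece
    let bpe = BPE::default();
    let wordpiece = WordPiece::from_bpe(&bpe);
    let _ = wordpiece;
}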
impl Model for WordPiece {
@@ -101,7 +134,7 @@ impl Model for WordPiece {
while start < end {
let mut substr = chars[start..end].iter().collect::<String>();
if start > 0 {
substr = format!("##{}", substr);
substr = format!("{}{}", self.continuing_subword_prefix, substr);
}
if self.vocab.contains_key(&substr) {
cur_str = Some(Token {
@@ -159,7 +192,7 @@ impl Model for WordPiece {
vocab_file.write_all(
&vocab
.into_iter()
.map(|(token, _)| token.as_bytes().to_owned())
.map(|(token, _)| format!("{}\n", token).as_bytes().to_owned())
.flatten()
.collect::<Vec<_>>()[..],
)?;

View File

@@ -0,0 +1,97 @@
use super::WordPiece;
use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder};
use crate::tokenizer::{Model, Result, Trainer};
use std::collections::{HashMap, HashSet};
#[derive(Default)]
pub struct WordPieceTrainerBuilder {
bpe_trainer_builder: BpeTrainerBuilder,
}
impl WordPieceTrainerBuilder {
/// Constructs a new `WordPieceTrainerBuilder`
pub fn new() -> Self {
Self::default()
}
/// Set the expected minimum frequency
pub fn min_frequency(mut self, frequency: u32) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
self
}
/// Set the vocabulary size
pub fn vocab_size(mut self, size: usize) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.vocab_size(size);
self
}
/// Set whether to show progress
pub fn show_progress(mut self, show: bool) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.show_progress(show);
self
}
/// Set the special tokens
pub fn special_tokens(mut self, tokens: Vec<String>) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.special_tokens(tokens);
self
}
/// Set whether to limit the alphabet
pub fn limit_alphabet(mut self, limit: usize) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.limit_alphabet(limit);
self
}
/// Set the initial alphabet
pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.initial_alphabet(alphabet);
self
}
/// Set the continuing_subword_prefix
pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.continuing_subword_prefix(prefix);
self
}
/// Set the end_of_word_suffix
pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.end_of_word_suffix(suffix);
self
}
/// Constructs the final WordPieceTrainer
pub fn build(self) -> Result<WordPieceTrainer> {
let bpe_trainer = self.bpe_trainer_builder.build()?;
Ok(WordPieceTrainer { bpe_trainer })
}
}
#[derive(Default)]
pub struct WordPieceTrainer {
bpe_trainer: BpeTrainer,
}
impl WordPieceTrainer {
pub fn builder() -> WordPieceTrainerBuilder {
WordPieceTrainerBuilder::default()
}
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<WordPiece> {
let bpe = self.bpe_trainer.train(word_counts)?;
Ok(WordPiece::from_bpe(&bpe))
}
}
impl Trainer for WordPieceTrainer {
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
let wp = self.train(word_counts)?;
Ok(Box::new(wp))
}
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
self.bpe_trainer.process_tokens(words, tokens)
}
}