diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index 37a6652..666f4a4 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -140,8 +140,8 @@ jobs: --test nllb --features download-libtorch - test-onnx: - name: Integration tests (ONNX models) + test-opt-features: + name: Integration tests (Optional features) runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -155,7 +155,9 @@ jobs: command: test args: --package rust-bert --features onnx + --features hf-tokenizers --test onnx + --test hf_tokenizers --features download-libtorch convert-model: diff --git a/CHANGELOG.md b/CHANGELOG.md index 8bbaf7b..072fe48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,8 @@ All notable changes to this project will be documented in this file. The format ## [Unreleased] ## Added -- Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines +- Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines. +- Support for [Tokenizers](https://github.com/huggingface/tokenizers) in pipelines, allowing loading `tokenizer.json` and `special_token_map.json` tokenizer files. ## Fixed - (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences). diff --git a/Cargo.toml b/Cargo.toml index 71520c4..5de1256 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,7 @@ download-libtorch = ["tch/download-libtorch"] onnx = ["ort", "ndarray"] rustls-tls = ["cached-path/rustls-tls"] default-tls = ["cached-path/default-tls"] +hf-tokenizers = ["tokenizers"] [package.metadata.docs.rs] features = ["doc-only"] @@ -89,6 +90,7 @@ dirs = { version = "4", optional = true } lazy_static = { version = "1", optional = true } ort = {version="~1.14.8", optional = true, default-features = false, features = ["half"]} ndarray = {version="0.15", optional = true} +tokenizers = {version="0.13.3", optional=true, default-features = false, features = ["onig"]} [dev-dependencies] anyhow = "1" diff --git a/examples/generation_gpt2_hf_tokenizers.rs b/examples/generation_gpt2_hf_tokenizers.rs new file mode 100644 index 0000000..7629f74 --- /dev/null +++ b/examples/generation_gpt2_hf_tokenizers.rs @@ -0,0 +1,61 @@ +// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +// Copyright 2019 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +extern crate anyhow; + +use rust_bert::pipelines::common::{ModelType, TokenizerOption}; +use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; +use rust_bert::resources::{RemoteResource, ResourceProvider}; +use std::fs::File; +use std::io::Write; +use tempfile::TempDir; + +fn main() -> anyhow::Result<()> { + // Set-up model + let generate_config = TextGenerationConfig { + model_type: ModelType::GPT2, + max_length: Some(30), + do_sample: false, + num_beams: 1, + temperature: 1.0, + num_return_sequences: 1, + ..Default::default() + }; + + // Create tokenizer + let tmp_dir = TempDir::new()?; + let special_token_map_path = tmp_dir.path().join("special_token_map.json"); + let mut tmp_file = File::create(&special_token_map_path)?; + writeln!( + tmp_file, + r#"{{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}}"# + )?; + + let tokenizer_path = RemoteResource::from_pretrained(( + "gpt2/tokenizer", + "https://huggingface.co/gpt2/resolve/main/tokenizer.json", + )) + .get_local_path()?; + let tokenizer = + TokenizerOption::from_hf_tokenizer_file(tokenizer_path, special_token_map_path)?; + + let model = TextGenerationModel::new_with_tokenizer(generate_config, tokenizer)?; + + let input_context = "The dog"; + let output = model.generate(&[input_context], None); + + for sentence in output { + println!("{sentence:?}"); + } + Ok(()) +} diff --git a/src/models/longt5/attention.rs b/src/models/longt5/attention.rs index e3cc116..5f73817 100644 --- a/src/models/longt5/attention.rs +++ b/src/models/longt5/attention.rs @@ -114,7 +114,9 @@ fn make_global_fixed_block_ids( attention_mask: &Tensor, global_block_size: i64, ) -> (Tensor, Tensor) { - let &[batch_size, seq_length, ..] = attention_mask.size().as_slice() else {unreachable!()}; + let &[batch_size, seq_length, ..]
= attention_mask.size().as_slice() else { + unreachable!() + }; let handle_orphan_tokens = |block_ids: Tensor| -> Tensor { let block_ends = Tensor::arange(seq_length, (Kind::Int64, block_ids.device())) diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index c90121e..5356671 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -56,13 +56,18 @@ use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; + use std::fmt::Debug; + use std::path::{Path, PathBuf}; use tch::{Device, Kind, Tensor}; #[cfg(feature = "onnx")] use crate::pipelines::onnx::ONNXModelConfig; +#[cfg(feature = "hf-tokenizers")] +use crate::pipelines::hf_tokenizers::HFTokenizer; + #[derive(Debug, Default)] /// Container for ONNX model resources, containing 3 optional resources (Encoder, Decoder and Decoder with past) pub struct ONNXModelResources { @@ -288,6 +293,9 @@ pub enum TokenizerOption { FNet(FNetTokenizer), /// Bart Tokenizer Bart(RobertaTokenizer), + /// HF Tokenizer + #[cfg(feature = "hf-tokenizers")] + HFTokenizer(HFTokenizer), } impl ConfigOption { @@ -913,28 +921,13 @@ impl TokenizerOption { Ok(tokenizer) } - /// Returns the model type - pub fn model_type(&self) -> ModelType { - match *self { - Self::Bert(_) => ModelType::Bert, - Self::Deberta(_) => ModelType::Deberta, - Self::DebertaV2(_) => ModelType::DebertaV2, - Self::Roberta(_) => ModelType::Roberta, - Self::Bart(_) => ModelType::Bart, - Self::XLMRoberta(_) => ModelType::XLMRoberta, - Self::Marian(_) => ModelType::Marian, - Self::T5(_) => ModelType::T5, - Self::Albert(_) => ModelType::Albert, - Self::XLNet(_) => ModelType::XLNet, - Self::GPT2(_) => ModelType::GPT2, - Self::OpenAiGpt(_) => ModelType::OpenAiGpt, - Self::Reformer(_) => ModelType::Reformer, - Self::ProphetNet(_) => ModelType::ProphetNet, - Self::Pegasus(_) => ModelType::Pegasus, - Self::MBart50(_) => ModelType::MBart, - Self::M2M100(_) | Self::NLLB(_) => ModelType::M2M100, - Self::FNet(_) => ModelType::FNet, - } + #[cfg(feature = "hf-tokenizers")] + pub fn from_hf_tokenizer_file<P: AsRef<Path>, S: AsRef<Path>>( + tokenizer_file: P, + special_token_map: S, + ) -> Result<Self, RustBertError> { + let hf_tokenizer = HFTokenizer::from_file(tokenizer_file, special_token_map)?; + Ok(TokenizerOption::HFTokenizer(hf_tokenizer)) } /// Interface method @@ -946,7 +939,7 @@ impl TokenizerOption { stride: usize, ) -> Vec<TokenizedInput> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { match *self { Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_list( @@ -1082,6 +1075,8 @@ impl TokenizerOption { truncation_strategy, stride, ), + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer.encode_list(text_list).unwrap(), } } @@ -1227,6 +1222,8 @@ impl TokenizerOption { truncation_strategy, stride, ), + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer.encode_pair_list(text_pair_list).unwrap(), } } @@ -1297,6 +1294,8 @@ impl TokenizerOption { Self::FNet(ref tokenizer) => { tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride) } + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer.encode_pair(text_1, text_2).unwrap(), } } @@ -1322,6 +1321,8 @@ impl TokenizerOption { Self::M2M100(ref tokenizer) => tokenizer.tokenize(text), Self::NLLB(ref tokenizer) => tokenizer.tokenize(text), Self::FNet(ref tokenizer) => tokenizer.tokenize(text), + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer)
=> tokenizer.tokenize(text), } } @@ -1347,13 +1348,15 @@ impl TokenizerOption { Self::M2M100(ref tokenizer) => tokenizer.tokenize_with_offsets(text), Self::NLLB(ref tokenizer) => tokenizer.tokenize_with_offsets(text), Self::FNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text), + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize_with_offsets(text), } } /// Interface method to tokenization pub fn tokenize_list<S>(&self, text: &[S]) -> Vec<Vec<String>> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { match *self { Self::Bert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), @@ -1383,6 +1386,8 @@ impl TokenizerOption { Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), Self::NLLB(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), Self::FNet(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize_list(text), } } @@ -1451,6 +1456,8 @@ impl TokenizerOption { Self::FNet(ref tokenizer) => { tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces) } + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer.decode(token_ids, skip_special_tokens), } } @@ -1537,6 +1544,13 @@ impl TokenizerOption { token_ids_with_offsets_1, token_ids_with_offsets_2, ), + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => { + return tokenizer.build_input_with_special_tokens( + token_ids_with_offsets_1, + token_ids_with_offsets_2, + ) + } }; TokenizedInput { token_ids: token_ids_with_special_tokens.token_ids, @@ -1736,6 +1750,8 @@ impl TokenizerOption { Self::M2M100(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), Self::NLLB(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), Self::FNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), } } @@ -1818,6 +1834,10 @@ impl TokenizerOption { let vocab = MultiThreadedTokenizer::vocab(tokenizer); vocab.token_to_id(vocab.get_unknown_value()) } + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => { + tokenizer.token_to_id(&tokenizer.special_token_map.unk_token) + } } } @@ -1888,6 +1908,12 @@ impl TokenizerOption { let vocab = MultiThreadedTokenizer::vocab(tokenizer); Some(vocab.token_to_id(vocab.get_pad_value())) } + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer + .special_token_map + .pad_token + .as_ref() + .map(|token| tokenizer.token_to_id(token)), Self::Reformer(_) => None, Self::GPT2(_) => None, Self::OpenAiGpt(_) => None, @@ -1949,6 +1975,12 @@ impl TokenizerOption { let vocab = MultiThreadedTokenizer::vocab(tokenizer); Some(vocab.token_to_id(vocab.get_sep_value())) } + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer + .special_token_map + .sep_token + .as_ref() + .map(|token| tokenizer.token_to_id(token)), Self::Marian(_) => None, Self::T5(_) => None, Self::GPT2(_) => None, @@ -2009,6 +2041,12 @@ impl TokenizerOption { let vocab = MultiThreadedTokenizer::vocab(tokenizer); Some(vocab.token_to_id(vocab.get_mask_value())) } + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer + .special_token_map + .mask_token + .as_ref() + .map(|token| tokenizer.token_to_id(token)), Self::Marian(_) => None, Self::M2M100(_) => None,
Self::NLLB(_) => None, @@ -2058,6 +2096,8 @@ impl TokenizerOption { Self::Pegasus(ref tokenizer) => { Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value()) } + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer.special_token_map.mask_token.as_deref(), Self::M2M100(_) => None, Self::NLLB(_) => None, Self::Marian(_) => None, @@ -2111,6 +2151,12 @@ impl TokenizerOption { let vocab = MultiThreadedTokenizer::vocab(tokenizer); Some(vocab.token_to_id(vocab.get_bos_value())) } + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer + .special_token_map + .bos_token + .as_ref() + .map(|token| tokenizer.token_to_id(token)), Self::MBart50(_) => Some(0), Self::FNet(_) => None, Self::Bert(_) => None, @@ -2186,6 +2232,12 @@ impl TokenizerOption { let vocab = MultiThreadedTokenizer::vocab(tokenizer); Some(vocab.token_to_id(vocab.get_eos_value())) } + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref tokenizer) => tokenizer + .special_token_map + .eos_token + .as_ref() + .map(|token| tokenizer.token_to_id(token)), Self::FNet(_) => None, Self::Bert(_) => None, Self::ProphetNet(_) => None, @@ -2264,6 +2316,8 @@ impl TokenizerOption { Self::M2M100(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids), Self::NLLB(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids), Self::FNet(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids), + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids), } } @@ -2289,6 +2343,8 @@ impl TokenizerOption { Self::M2M100(ref mut tokenizer) => tokenizer.add_tokens(tokens), Self::NLLB(ref mut tokenizer) => tokenizer.add_tokens(tokens), Self::FNet(ref mut tokenizer) => tokenizer.add_tokens(tokens), + #[cfg(feature = "hf-tokenizers")] + Self::HFTokenizer(ref mut tokenizer) => tokenizer.add_tokens(tokens), } } } diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs index 8869dec..a6cc389 100644 --- a/src/pipelines/generation_utils.rs +++ b/src/pipelines/generation_utils.rs @@ -381,7 +381,7 @@ pub(crate) mod private_generation_utils { pad_token_id: Option<i64>, ) -> Tensor where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { let token_ids = if self.is_encoder_decoder() { let tokens = self._get_tokenizer().encode_list( @@ -1774,7 +1774,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator { generate_options: Option<GenerateOptions>, ) -> Vec<GeneratedTextOutput> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { let indices_outputs = self.generate_indices(prompt_texts, generate_options); let mut output = Vec::with_capacity(indices_outputs.len()); @@ -1868,7 +1868,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator { generate_options: Option<GenerateOptions>, ) -> Vec<GeneratedIndicesOutput> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { let eos_token_ids = self.get_eos_ids(); diff --git a/src/pipelines/hf_tokenizers.rs b/src/pipelines/hf_tokenizers.rs new file mode 100644 index 0000000..d93302b --- /dev/null +++ b/src/pipelines/hf_tokenizers.rs @@ -0,0 +1,764 @@ +/// # Support for [tokenizers](https://github.com/huggingface/tokenizers) +/// +/// This module implements interface methods to allow loading tokenizers trained and implemented with +/// the [Tokenizers](https://github.com/huggingface/tokenizers) crate. While the functionality of these tokenizers +/// is expected to be identical to the default [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers) used +/// in this crate, the implementation and input file format differ.
+/// +/// Because some of the logic related to the special token handling is implemented at the Python level using the rust bindings, +/// the proposed implementation requires two files to be provided: +/// - `tokenizer.json` containing the tokenizer model, pre- and post-processing options and vocabulary +/// - `special_token_map.json` containing a mapping of the special tokens used by the model (e.g. BOS and CLS values) +use crate::RustBertError; +use rust_tokenizers::{ + Mask, Offset, OffsetSize, TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets, +}; +use serde::{de, Deserialize, Deserializer}; +use std::borrow::Cow; +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::fs::File; +use std::io::BufReader; +use std::path::Path; +use tokenizers::tokenizer::Tokenizer as HFBaseTokenizer; +use tokenizers::{AddedToken, EncodeInput, Encoding, InputSequence}; + +impl From for RustBertError { + fn from(error: tokenizers::tokenizer::Error) -> Self { + RustBertError::TokenizerError(error.to_string()) + } +} + +/// Container for a special token map to be deserialized from a `special_token_map.json` +#[derive(Debug, Default, Clone, Deserialize)] +pub struct SpecialTokenMap { + /// Unknown token (must be provided for all tokenizers) + pub unk_token: String, + /// Optional padding token + #[serde(default)] + #[serde(deserialize_with = "string_or_added_token_struct")] + pub pad_token: Option, + /// Optional bos token + #[serde(default)] + #[serde(deserialize_with = "string_or_added_token_struct")] + pub bos_token: Option, + /// Optional sep token + #[serde(default)] + #[serde(deserialize_with = "string_or_added_token_struct")] + pub sep_token: Option, + /// Optional cls token + #[serde(default)] + #[serde(deserialize_with = "string_or_added_token_struct")] + pub cls_token: Option, + /// Optional eos token + #[serde(default)] + #[serde(deserialize_with = "string_or_added_token_struct")] + pub eos_token: Option, + /// Optional mask token + #[serde(default)] + #[serde(deserialize_with = "string_or_added_token_struct")] + pub mask_token: Option, + /// Optional additional special tokens + pub additional_special_tokens: Option>, +} + +/// Deserialization utility function for `special_token_map.json` to read nested special tokens structure +fn string_or_added_token_struct<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + struct StringOrStruct; + + impl<'de> de::Visitor<'de> for StringOrStruct { + type Value = Option; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("string or map") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(Some(value.to_string())) + } + + fn visit_map(self, mut map: M) -> Result + where + M: de::MapAccess<'de>, + { + let mut value = None; + while let Some(key) = map.next_key::()? { + if key == "content" { + value = Some(map.next_value::()?); + } else { + _ = map.next_value::(); + } + } + Ok(value) + } + } + deserializer.deserialize_any(StringOrStruct) +} + +/// Base class for a tokenizer from the Tokenizers library +pub struct HFTokenizer { + /// Base tokenizer object + tokenizer: HFBaseTokenizer, + /// Special token map + pub(crate) special_token_map: SpecialTokenMap, +} + +impl HFTokenizer { + /// Create a new tokenizer from a file. 
+ /// + /// # Arguments + /// - `tokenizer_file` path to location containing the tokenizer model, pre- and post-processing options and vocabulary + /// - `special_token_map` path to location containing a mapping of the special tokens used by the model (e.g. BOS and CLS values) + /// + /// # Returns + /// - Wrapper around a tokenizer that can be loaded in a `TokenizerOption` in this crate + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// use std::path::PathBuf; + /// let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// # Ok(()) + /// # } + /// ``` + pub fn from_file<P: AsRef<Path>, S: AsRef<Path>>( + tokenizer_file: P, + special_token_map: S, + ) -> Result<Self, RustBertError> { + let tokenizer = HFBaseTokenizer::from_file(tokenizer_file)?; + let f = File::open(&special_token_map).map_err(|e| { + RustBertError::IOError(format!( + "{} special token map file not found: {}", + special_token_map.as_ref().display(), + e + )) + })?; + let br = BufReader::new(f); + let special_token_map = serde_json::from_reader(br).map_err(|e| { + RustBertError::IOError(format!("Invalid special token mapping file: {e}")) + })?; + Ok(Self { + tokenizer, + special_token_map, + }) + } + + fn encoding_to_tokenized_input(encoding: Encoding) -> TokenizedInput { + let token_ids = encoding + .get_ids() + .iter() + .map(|token_id| *token_id as i64) + .collect(); + let segment_ids = encoding + .get_type_ids() + .iter() + .map(|segment_id| *segment_id as i8) + .collect(); + let special_tokens_mask = encoding + .get_special_tokens_mask() + .iter() + .map(|segment_id| *segment_id as i8) + .collect(); + let overflowing_tokens: Vec<i64> = encoding + .get_overflowing() + .iter() + .flat_map(|encoding| encoding.get_ids()) + .map(|token_id| *token_id as i64) + .collect(); + let num_truncated_tokens = overflowing_tokens.len(); + let token_offsets = encoding + .get_offsets() + .iter() + .map(|offset| { + Some(Offset { + begin: offset.0 as OffsetSize, + end: offset.1 as OffsetSize, + }) + }) + .collect(); + let reference_offsets = encoding + .get_offsets() + .iter() + .map(|offset| (offset.0 as OffsetSize..offset.1 as OffsetSize).collect()) + .collect(); + let mask = encoding + .get_special_tokens_mask() + .iter() + .map(|segment_id| { + if *segment_id == 0 { + Mask::None + } else { + Mask::Special + } + }) + .collect(); + TokenizedInput { + token_ids, + segment_ids, + special_tokens_mask, + overflowing_tokens, + num_truncated_tokens, + token_offsets, + reference_offsets, + mask, + } + } + + /// Encode a list of texts + /// + /// # Arguments + /// - `text_list` slice of string-like inputs to encode + /// + /// # Returns + /// - `Vec<TokenizedInput>` containing the tokenized and encoded texts + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// let texts = &["first text to encode", "second text to encode"]; + /// let output = tokenizer.encode_list(texts); + /// # Ok(()) + /// # } + /// ``` + pub fn encode_list<S>(&self,
text_list: &[S]) -> Result<Vec<TokenizedInput>, RustBertError> + where + S: AsRef<str> + Sync + Send, + { + let encoding_inputs = text_list.iter().map(|text| text.as_ref()).collect(); + let encodings = self.tokenizer.encode_batch(encoding_inputs, true)?; + let mut tokenized_inputs: Vec<TokenizedInput> = Vec::with_capacity(encodings.len()); + for encoding in encodings { + tokenized_inputs.push(Self::encoding_to_tokenized_input(encoding)); + } + + Ok(tokenized_inputs) + } + + /// Encode a list of text pairs + /// + /// This is used for applications where the model takes two input sequences as input (e.g. natural language inference). + /// + /// # Arguments + /// - `text_pair_list` slice of tuples of string-like inputs to encode + /// + /// # Returns + /// - `Vec<TokenizedInput>` containing the tokenized and encoded texts + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// let texts = &[ + /// ( + /// "first text of first pair to encode", + /// "second text of first pair to encode", + /// ), + /// ( + /// "first text of second pair to encode", + /// "second text of second pair to encode", + /// ), + /// ]; + /// let output = tokenizer.encode_pair_list(texts); + /// # Ok(()) + /// # } + /// ``` + pub fn encode_pair_list( + &self, + text_pair_list: &[(&str, &str)], + ) -> Result<Vec<TokenizedInput>, RustBertError> { + let encoding_inputs: Vec<EncodeInput> = text_pair_list + .iter() + .map(|(text_1, text_2)| { + EncodeInput::Dual( + InputSequence::Raw(Cow::Borrowed(text_1)), + InputSequence::Raw(Cow::Borrowed(text_2)), + ) + }) + .collect(); + let encodings = self.tokenizer.encode_batch(encoding_inputs, true)?; + let mut tokenized_inputs: Vec<TokenizedInput> = Vec::with_capacity(encodings.len()); + for encoding in encodings { + tokenized_inputs.push(Self::encoding_to_tokenized_input(encoding)); + } + + Ok(tokenized_inputs) + } + + /// Encode a single text pair + /// + /// This is used for applications where the model takes two input sequences as input (e.g. natural language inference).
+ /// This generic method handles both the case where a second input is provided and the case where it is not + /// (falling back to single sequence encoding) + /// + /// # Arguments + /// - `text_1` string slice for the first text + /// - `text_2` Optional string slice for the second text + /// + /// # Returns + /// - `TokenizedInput` containing the tokenized and encoded texts + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// let text_1 = "first text to encode"; + /// let output_1 = tokenizer.encode_pair(text_1, None); + /// let text_2 = "second text to encode"; + /// let output_2 = tokenizer.encode_pair(text_1, Some(text_2)); + /// # Ok(()) + /// # } + /// ``` + pub fn encode_pair( + &self, + text_1: &str, + text_2: Option<&str>, + ) -> Result<TokenizedInput, RustBertError> { + let encoding_input = if let Some(text_2) = text_2 { + EncodeInput::Dual( + InputSequence::Raw(Cow::Borrowed(text_1)), + InputSequence::Raw(Cow::Borrowed(text_2)), + ) + } else { + EncodeInput::Single(InputSequence::Raw(Cow::Borrowed(text_1))) + }; + let encoding = self.tokenizer.encode(encoding_input, true)?; + Ok(Self::encoding_to_tokenized_input(encoding)) + } + + /// Tokenize a text + /// + /// # Arguments + /// - `text` string slice to tokenize + /// + /// # Returns + /// - `Vec<String>` tokenized text + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// let text = "first text to encode"; + /// let output = tokenizer.tokenize(text); + /// # Ok(()) + /// # } + /// ``` + pub fn tokenize(&self, text: &str) -> Vec<String> { + self.tokenizer + .encode(text, false) + .unwrap() + .get_tokens() + .to_vec() + } + + /// Tokenize a list of texts + /// + /// # Arguments + /// - `texts` slice of string-like references to tokenize + /// + /// # Returns + /// - `Vec<Vec<String>>` tokenized texts + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// let texts = &["first text to encode", "second text to encode"]; + /// let output = tokenizer.tokenize_list(texts); + /// # Ok(()) + /// # } + /// ``` + pub fn tokenize_list<S>(&self, texts: &[S]) -> Vec<Vec<String>> + where + S: AsRef<str> + Send + Sync, + { + texts + .iter() + .map(|text| self.tokenize(text.as_ref())) + .collect() + } + + /// Tokenize a text with offsets information + /// + /// # Arguments + /// - `text` string slice to tokenize with offsets + /// + /// # Returns + /// - `TokensWithOffsets` tokenized text with offset information + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use
rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// let text = "first text to encode"; + /// let output = tokenizer.tokenize_with_offsets(text); + /// # Ok(()) + /// # } + /// ``` + pub fn tokenize_with_offsets(&self, text: &str) -> TokensWithOffsets { + let encoding = self.tokenizer.encode(text, false).unwrap(); + let tokens = encoding.get_tokens().to_vec(); + let offsets = encoding + .get_offsets() + .iter() + .map(|offset| { + Some(Offset { + begin: offset.0 as OffsetSize, + end: offset.1 as OffsetSize, + }) + }) + .collect(); + let reference_offsets = encoding + .get_offsets() + .iter() + .map(|offset| (offset.0 as OffsetSize..offset.1 as OffsetSize).collect()) + .collect(); + let masks = encoding + .get_special_tokens_mask() + .iter() + .map(|segment_id| { + if *segment_id == 0 { + Mask::None + } else { + Mask::Special + } + }) + .collect(); + TokensWithOffsets { + tokens, + offsets, + reference_offsets, + masks, + } + } + + /// Decode a sequence of token ids to a text + /// + /// # Arguments + /// - `token_ids` slice of token ids + /// - `skip_special_tokens` flag indicating if special token ids should be skipped during decoding + /// + /// # Returns + /// - `String` decoded text + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// let token_ids = &[0, 2, 5, 9, 4, 2, 1]; + /// let skip_special_tokens = true; + /// let output = tokenizer.decode(token_ids, skip_special_tokens); + /// # Ok(()) + /// # } + /// ``` + pub fn decode(&self, token_ids: &[i64], skip_special_tokens: bool) -> String { + self.tokenizer + .decode( + token_ids.iter().map(|token_id| *token_id as u32).collect(), + skip_special_tokens, + ) + .unwrap() + } + + fn token_ids_with_offsets_to_encoding( + &self, + token_ids_with_offsets: TokenIdsWithOffsets, + ) -> Encoding { + let ids: Vec<u32> = token_ids_with_offsets + .ids + .iter() + .map(|token_id| *token_id as u32) + .collect(); + let type_ids = token_ids_with_offsets + .ids + .iter() + .map(|segment_id| *segment_id as u32) + .collect(); + let tokens = ids + .iter() + .map(|token_id| { + self.tokenizer + .id_to_token(*token_id) + .unwrap_or(self.tokenizer.decode(vec![*token_id], false).unwrap()) + }) + .collect(); + let words = vec![None::<u32>; ids.len()]; + let offsets = token_ids_with_offsets + .offsets + .iter() + .map(|offset| { + offset + .map(|offset| (offset.begin as usize, offset.end as usize)) + .unwrap_or((0, 0)) + }) + .collect(); + let special_tokens_mask = token_ids_with_offsets + .masks + .iter() + .map(|segment_id| match segment_id { + Mask::Special => 1, + _ => 0, + }) + .collect(); + let overflowing: Vec<Encoding> = vec![]; + let attention_mask = vec![1; ids.len()]; + let sequence_ranges = HashMap::new(); + Encoding::new( + ids, + type_ids, + tokens, + words, + offsets, + special_tokens_mask, + attention_mask, + overflowing, + sequence_ranges, + ) + } + + /// Post-process a sequence or sequence pair
+ /// + /// Adds the special tokens for a single sequence or a pair of sequences and applies the tokenizer post-processing + /// + /// # Arguments + /// - `token_ids_with_offsets_1` first sequence's `TokenIdsWithOffsets` + /// - `token_ids_with_offsets_2` optional second sequence's `TokenIdsWithOffsets` + /// + /// # Returns + /// - `TokenizedInput` post-processed encoding for the inputs provided. + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// use rust_tokenizers::{Offset, TokenIdsWithOffsets}; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// let token_ids_with_offsets_1 = TokenIdsWithOffsets { + ///     ids: vec![0, 1, 2], + ///     offsets: vec![ + ///         Some(Offset { begin: 0, end: 1 }), + ///         Some(Offset { begin: 1, end: 2 }), + ///         Some(Offset { begin: 2, end: 3 }), + ///     ], + ///     reference_offsets: vec![vec![0], vec![1], vec![2]], + ///     masks: vec![], + /// }; + /// let token_ids_with_offsets_2 = TokenIdsWithOffsets { + ///     ids: vec![8, 9, 10], + ///     offsets: vec![ + ///         Some(Offset { begin: 3, end: 4 }), + ///         Some(Offset { begin: 4, end: 5 }), + ///         Some(Offset { begin: 5, end: 6 }), + ///     ], + ///     reference_offsets: vec![vec![3], vec![4], vec![5]], + ///     masks: vec![], + /// }; + /// let output = tokenizer + ///     .build_input_with_special_tokens(token_ids_with_offsets_1, Some(token_ids_with_offsets_2)); + /// # Ok(()) + /// # } + /// ``` + pub fn build_input_with_special_tokens( + &self, + token_ids_with_offsets_1: TokenIdsWithOffsets, + token_ids_with_offsets_2: Option<TokenIdsWithOffsets>, + ) -> TokenizedInput { + let encoding_1 = self.token_ids_with_offsets_to_encoding(token_ids_with_offsets_1); + let encoding_2 = token_ids_with_offsets_2 + .map(|encoding| self.token_ids_with_offsets_to_encoding(encoding)); + let encoding_output = self + .tokenizer + .post_process(encoding_1, encoding_2, true) + .unwrap(); + Self::encoding_to_tokenized_input(encoding_output) + } + + /// Converts a single token to a token id + /// + /// Returns the unknown token id if the item is not present in the tokenizer vocabulary. + /// + /// # Arguments + /// - `token` string slice to convert + /// + /// # Returns + /// - `i64` token id (or unknown token id if not found in the vocabulary) + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// let token = "Hello"; + /// let output = tokenizer.token_to_id(token); + /// # Ok(()) + /// # } + /// ``` + pub fn token_to_id(&self, token: &str) -> i64 { + self.tokenizer.token_to_id(token.as_ref()).unwrap_or( + self.tokenizer + .token_to_id(self.special_token_map.unk_token.as_str()) + .unwrap(), + ) as i64 + } + + /// Converts a slice of tokens to token ids + /// + /// Returns the unknown token id if the item is not present in the tokenizer vocabulary.
+ /// + /// # Arguments + /// - `tokens` slice of string slices to convert + /// + /// # Returns + /// - `Vec<i64>` token ids (with unknown token id at position of items not found in the vocabulary) + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// let tokens = &["Hello", "world", "!"]; + /// let output = tokenizer.convert_tokens_to_ids(tokens); + /// # Ok(()) + /// # } + /// ``` + pub fn convert_tokens_to_ids<S>(&self, tokens: &[S]) -> Vec<i64> + where + S: AsRef<str>, + { + tokens + .iter() + .map(|token| self.token_to_id(token.as_ref())) + .collect() + } + + /// Add tokens to the tokenizer vocabulary + /// + /// These tokens are not used by the tokenization algorithm and simply added to the vocabulary + /// + /// # Arguments + /// - `tokens` tokens to add to the vocabulary + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let mut tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// tokenizer.add_tokens(&["<custom_token_1>", "<custom_token_2>"]); + /// # Ok(()) + /// # } + /// ``` + pub fn add_tokens(&mut self, tokens: &[&str]) { + let added_tokens = tokens + .iter() + .map(|token| AddedToken { + content: token.to_string(), + single_word: false, + lstrip: false, + rstrip: false, + normalized: false, + special: false, + }) + .collect::<Vec<AddedToken>>(); + self.tokenizer.add_tokens(&added_tokens); + } + + /// Add extra token ids to the tokenizer vocabulary + /// + /// These tokens are automatically formatted as "<extra_id_N>", with N in 0..num_extra_ids + /// + /// # Arguments + /// - `num_extra_ids` number of tokens to add + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer; + /// # use std::path::PathBuf; + /// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json"); + /// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json"); + /// let mut tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?; + /// tokenizer.add_extra_ids(42); + /// # Ok(()) + /// # } + /// ``` + pub fn add_extra_ids(&mut self, num_extra_ids: i64) { + let mut added_tokens: Vec<AddedToken> = Vec::with_capacity(num_extra_ids as usize); + for extra_id in 0..num_extra_ids { + added_tokens.push(AddedToken { + content: format!("<extra_id_{extra_id}>"), + single_word: false, + lstrip: false, + rstrip: false, + normalized: false, + special: false, + }); + } + self.tokenizer.add_tokens(&added_tokens); + } +} diff --git a/src/pipelines/keywords_extraction/pipeline.rs b/src/pipelines/keywords_extraction/pipeline.rs index b8d9ea6..a4e18ef 100644 --- a/src/pipelines/keywords_extraction/pipeline.rs +++ b/src/pipelines/keywords_extraction/pipeline.rs @@ -220,7 +220,7 @@ impl<'a> KeywordExtractionModel<'a> { /// ``` pub fn
predict<S>(&self, inputs: &[S]) -> Result<Vec<Vec<Keyword>>, RustBertError> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { let words = self.tokenizer.tokenize_list(inputs, self.ngram_range); let (flat_word_list, document_boundaries) = diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index 7711293..e886d69 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -473,6 +473,56 @@ //! ] //! # ; //! ``` +//! +//! # [Tokenizers](https://github.com/huggingface/tokenizers) support +//! +//! The pipelines support both the default [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers) and +//! Hugging Face's [Tokenizers](https://github.com/huggingface/tokenizers) library. In order to use the latter, +//! the tokenizer needs to be created manually and passed as an argument to the pipeline's `new_with_tokenizer` method. +//! +//! Note that a `special_token_map.json` file is required to create a `TokenizerOption` from an `HFTokenizer`. This file is sometimes not provided +//! (the Python Transformers library stores the special token map information as part of the Python tokenizer wrapping the rust-based +//! tokenizer). If that is the case, a temporary file with the special token map information can be created as illustrated below: +//! ```no_run +//! fn main() -> anyhow::Result<()> { +//! use std::fs::File; +//! use std::io::Write; +//! use tempfile::TempDir; +//! use rust_bert::pipelines::common::{ModelType, TokenizerOption}; +//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; +//! use rust_bert::resources::{RemoteResource, ResourceProvider}; +//! +//! let generate_config = TextGenerationConfig { +//! model_type: ModelType::GPT2, +//! ..Default::default() +//! }; +//! +//! // Create tokenizer +//! let tmp_dir = TempDir::new()?; +//! let special_token_map_path = tmp_dir.path().join("special_token_map.json"); +//! let mut tmp_file = File::create(&special_token_map_path)?; +//! writeln!( +//! tmp_file, +//! r#"{{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}}"# +//! )?; +//! let tokenizer_path = RemoteResource::from_pretrained(( +//! "gpt2/tokenizer", +//! "https://huggingface.co/gpt2/resolve/main/tokenizer.json", +//! )).get_local_path()?; +//! let tokenizer = +//! TokenizerOption::from_hf_tokenizer_file(tokenizer_path, special_token_map_path)?; +//! +//! // Create model +//! let model = TextGenerationModel::new_with_tokenizer(generate_config, tokenizer)?; +//! +//! let input_context = "The dog"; +//! let output = model.generate(&[input_context], None); +//! for sentence in output { +//! println!("{sentence:?}"); +//! } +//! Ok(()) +//! } +//!
``` pub mod common; pub mod conversation; @@ -493,3 +543,6 @@ pub mod zero_shot_classification; #[cfg(feature = "onnx")] pub mod onnx; + +#[cfg(feature = "hf-tokenizers")] +pub mod hf_tokenizers; diff --git a/src/pipelines/sentence_embeddings/pipeline.rs b/src/pipelines/sentence_embeddings/pipeline.rs index f96e638..f9d8f9a 100644 --- a/src/pipelines/sentence_embeddings/pipeline.rs +++ b/src/pipelines/sentence_embeddings/pipeline.rs @@ -314,7 +314,7 @@ impl SentenceEmbeddingsModel { /// Tokenizes the inputs pub fn tokenize<S>(&self, inputs: &[S]) -> SentenceEmbeddingsTokenizerOutput where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { let tokenized_input = self.tokenizer.encode_list( inputs, @@ -368,7 +368,7 @@ impl SentenceEmbeddingsModel { inputs: &[S], ) -> Result<SentenceEmbeddingsModelOutput, RustBertError> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { let SentenceEmbeddingsTokenizerOutput { tokens_ids, @@ -413,7 +413,7 @@ impl SentenceEmbeddingsModel { /// Computes sentence embeddings. pub fn encode<S>(&self, inputs: &[S]) -> Result<Vec<Embedding>, RustBertError> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { let SentenceEmbeddingsModelOutput { embeddings, .. } = self.encode_as_tensor(inputs)?; Ok(Vec::try_from(embeddings)?) @@ -457,7 +457,7 @@ impl SentenceEmbeddingsModel { inputs: &[S], ) -> Result<(Vec<Embedding>, Vec<AttentionOutput>), RustBertError> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { let SentenceEmbeddingsModelOutput { embeddings, diff --git a/src/pipelines/summarization.rs b/src/pipelines/summarization.rs index 4dd08bd..ed5a26f 100644 --- a/src/pipelines/summarization.rs +++ b/src/pipelines/summarization.rs @@ -336,7 +336,7 @@ impl SummarizationOption { /// Interface method to generate() of the particular models. pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Vec<String> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { match *self { Self::Bart(ref model) => model @@ -504,7 +504,7 @@ impl SummarizationModel { /// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)) pub fn summarize<S>(&self, texts: &[S]) -> Vec<String> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { match &self.prefix { None => self.model.generate(Some(texts)), diff --git a/src/pipelines/text_generation.rs b/src/pipelines/text_generation.rs index c434e54..97c2307 100644 --- a/src/pipelines/text_generation.rs +++ b/src/pipelines/text_generation.rs @@ -333,7 +333,7 @@ impl TextGenerationOption { max_length: Option<i64>, ) -> Vec<Vec<i64>> where - S: AsRef<str> + Sync, + S: AsRef<str> + Send + Sync, { let generate_options = Some(GenerateOptions { min_length, @@ -597,7 +597,7 @@ with people, even a bishop, begging for his blessing.
" /// ``` pub fn generate<'a, S>(&self, texts: &[S], prefix: impl Into>) -> Vec where - S: AsRef + Sync, + S: AsRef + Send + Sync, { let (prefix, prefix_length) = match (prefix.into(), &self.prefix) { (Some(query_prefix), _) => ( diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs index 4629a53..3795727 100644 --- a/src/pipelines/token_classification.rs +++ b/src/pipelines/token_classification.rs @@ -128,7 +128,6 @@ use crate::resources::ResourceProvider; use crate::roberta::RobertaForTokenClassification; use crate::xlnet::XLNetForTokenClassification; use ordered_float::OrderedFloat; -use rust_tokenizers::tokenizer::Tokenizer; use rust_tokenizers::{ ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenIdsWithOffsets, TokenTrait, TokenizedInput, @@ -1103,27 +1102,7 @@ impl TokenClassificationModel { let offsets = &sentence_tokens.offsets[position_idx as usize]; let text = match offsets { - None => match self.tokenizer { - TokenizerOption::Bert(ref tokenizer) => { - Tokenizer::decode(tokenizer, &[token_id], false, false) - } - TokenizerOption::Roberta(ref tokenizer) => { - Tokenizer::decode(tokenizer, &[token_id], false, false) - } - TokenizerOption::XLMRoberta(ref tokenizer) => { - Tokenizer::decode(tokenizer, &[token_id], false, false) - } - TokenizerOption::Albert(ref tokenizer) => { - Tokenizer::decode(tokenizer, &[token_id], false, false) - } - TokenizerOption::XLNet(ref tokenizer) => { - Tokenizer::decode(tokenizer, &[token_id], false, false) - } - _ => panic!( - "Token classification not implemented for {:?}!", - self.tokenizer.model_type() - ), - }, + None => self.tokenizer.decode(&[token_id], false, false), Some(offsets) => { let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize); let end_char = min(end_char, original_sentence_chars.len()); diff --git a/src/pipelines/translation/translation_pipeline.rs b/src/pipelines/translation/translation_pipeline.rs index 6445a9b..f4ff01e 100644 --- a/src/pipelines/translation/translation_pipeline.rs +++ b/src/pipelines/translation/translation_pipeline.rs @@ -1217,7 +1217,7 @@ impl TranslationOption { forced_bos_token_id: Option, ) -> Vec where - S: AsRef + Sync, + S: AsRef + Send + Sync, { match *self { Self::Marian(ref model) => model @@ -1470,7 +1470,7 @@ impl TranslationModel { target_language: impl Into>, ) -> Result, RustBertError> where - S: AsRef + Sync, + S: AsRef + Send + Sync, { let (prefix, forced_bos_token_id) = self.model.get_tokenizer().get_prefix_and_forced_bos_id( diff --git a/tests/hf_tokenizers.rs b/tests/hf_tokenizers.rs new file mode 100644 index 0000000..667d3a4 --- /dev/null +++ b/tests/hf_tokenizers.rs @@ -0,0 +1,112 @@ +#[cfg(feature = "hf-tokenizers")] +mod tests { + use rust_bert::gpt2::{Gpt2ConfigResources, Gpt2ModelResources}; + use rust_bert::pipelines::common::{ModelResource, ModelType, TokenizerOption}; + use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel}; + use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; + use rust_bert::resources::{LocalResource, RemoteResource, ResourceProvider}; + use std::fs::File; + use std::io::Write; + use tch::Device; + use tempfile::TempDir; + + #[test] + fn gpt2_generation() -> anyhow::Result<()> { + let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)); + let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); + let dummy_vocab_resource = Box::new(LocalResource { + 
local_path: Default::default(), + }); + let tokenizer_resource = Box::new(RemoteResource::from_pretrained(( + "gpt2/tokenizer", + "https://huggingface.co/gpt2/resolve/main/tokenizer.json", + ))); + + let generate_config = TextGenerationConfig { + model_type: ModelType::GPT2, + model_resource: ModelResource::Torch(model_resource), + config_resource, + vocab_resource: dummy_vocab_resource, + merges_resource: None, + max_length: Some(20), + do_sample: false, + num_beams: 5, + temperature: 1.2, + device: Device::Cpu, + num_return_sequences: 3, + ..Default::default() + }; + + // Create tokenizer + let tmp_dir = TempDir::new()?; + let special_token_map_path = tmp_dir.path().join("special_token_map.json"); + let mut tmp_file = File::create(&special_token_map_path)?; + writeln!( + tmp_file, + r#"{{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}}"# + )?; + + let tokenizer_path = tokenizer_resource.get_local_path()?; + let tokenizer = + TokenizerOption::from_hf_tokenizer_file(tokenizer_path, special_token_map_path)?; + + let model = TextGenerationModel::new_with_tokenizer(generate_config, tokenizer)?; + + let input_context = "The dog"; + let output = model.generate(&[input_context], None); + + assert_eq!(output.len(), 3); + assert_eq!( + output[0], + "The dog was found in the backyard of a home in the 6200 block of South Main Street." + ); + assert_eq!( + output[1], + "The dog was found in the backyard of a home in the 6500 block of South Main Street." + ); + assert_eq!( + output[2], + "The dog was found in the backyard of a home in the 6200 block of South Main Street," + ); + Ok(()) + } + + #[test] + fn distilbert_question_answering() -> anyhow::Result<()> { + // Create tokenizer + let tmp_dir = TempDir::new()?; + let special_token_map_path = tmp_dir.path().join("special_token_map.json"); + let mut tmp_file = File::create(&special_token_map_path)?; + + writeln!( + tmp_file, + r#"{{"pad_token": "[PAD]", "sep_token": "[SEP]", "cls_token": "[CLS]", "mask_token": "[MASK]", "unk_token": "[UNK]"}}"# + )?; + let tokenizer_resource = Box::new(RemoteResource::from_pretrained(( + "distilbert-base-cased-distilled-squad/tokenizer", + "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/tokenizer.json", + ))); + let tokenizer_path = tokenizer_resource.get_local_path()?; + let tokenizer = + TokenizerOption::from_hf_tokenizer_file(tokenizer_path, special_token_map_path)?; + + // Set-up question answering model + let qa_model = QuestionAnsweringModel::new_with_tokenizer(Default::default(), tokenizer)?; + + // Define input + let question = String::from("Where does Amy live ?"); + let context = String::from("Amy lives in Amsterdam"); + let qa_input = QaInput { question, context }; + + let answers = qa_model.predict(&[qa_input], 1, 32); + + assert_eq!(answers.len(), 1usize); + assert_eq!(answers[0].len(), 1usize); + assert_eq!(answers[0][0].start, 13); + assert_eq!(answers[0][0].end, 22); + assert!((answers[0][0].score - 0.9978).abs() < 1e-4); + assert_eq!(answers[0][0].answer, "Amsterdam"); + + Ok(()) + } +}
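The `string_or_added_token_struct` deserializer defined in `src/pipelines/hf_tokenizers.rs` accepts each optional special token either as a bare string or as an `AddedToken`-style map with a `content` field. A minimal, self-contained sketch of the same serde pattern (the field set is trimmed to two tokens for brevity; only `serde` and `serde_json` are assumed as dependencies):

```rust
use serde::{de, Deserialize, Deserializer};
use std::fmt;

#[derive(Debug, Deserialize)]
struct SpecialTokenMap {
    unk_token: String,
    #[serde(default, deserialize_with = "string_or_added_token_struct")]
    bos_token: Option<String>,
}

// Same visitor pattern as in the patch: accept either a bare string or a map
// containing a `content` key, ignoring any other keys (`lstrip`, `rstrip`, ...).
fn string_or_added_token_struct<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
    D: Deserializer<'de>,
{
    struct StringOrStruct;

    impl<'de> de::Visitor<'de> for StringOrStruct {
        type Value = Option<String>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("string or map")
        }

        fn visit_str<E: de::Error>(self, value: &str) -> Result<Self::Value, E> {
            Ok(Some(value.to_string()))
        }

        fn visit_map<M: de::MapAccess<'de>>(self, mut map: M) -> Result<Self::Value, M::Error> {
            let mut value = None;
            while let Some(key) = map.next_key::<String>()? {
                if key == "content" {
                    value = Some(map.next_value::<String>()?);
                } else {
                    let _ = map.next_value::<de::IgnoredAny>()?;
                }
            }
            Ok(value)
        }
    }
    deserializer.deserialize_any(StringOrStruct)
}

fn main() -> serde_json::Result<()> {
    // Plain-string form.
    let flat: SpecialTokenMap =
        serde_json::from_str(r#"{"unk_token": "<unk>", "bos_token": "<s>"}"#)?;
    assert_eq!(flat.bos_token.as_deref(), Some("<s>"));

    // Nested `AddedToken`-style form; extra keys are skipped.
    let nested: SpecialTokenMap = serde_json::from_str(
        r#"{"unk_token": "<unk>", "bos_token": {"content": "<s>", "lstrip": false}}"#,
    )?;
    assert_eq!(nested.bos_token.as_deref(), Some("<s>"));
    Ok(())
}
```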
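Beyond the pipeline-level tests above, the new `TokenizerOption::from_hf_tokenizer_file` constructor can also be exercised directly. A minimal sketch, assuming the `hf-tokenizers` feature is enabled and the placeholder paths point to a real `tokenizer.json` and `special_token_map.json`:

```rust
use rust_bert::pipelines::common::TokenizerOption;

fn main() -> anyhow::Result<()> {
    // Placeholder paths: any `tokenizer.json` exported by the Tokenizers library,
    // plus a matching `special_token_map.json` as described in the module docs.
    let tokenizer = TokenizerOption::from_hf_tokenizer_file(
        "path/to/tokenizer.json",
        "path/to/special_token_map.json",
    )?;

    // Round-trip a sentence through the unified tokenizer interface:
    // tokenize, convert tokens to ids, then decode back to text.
    let tokens = tokenizer.tokenize("The dog");
    let ids = tokenizer.convert_tokens_to_ids(&tokens);
    let decoded = tokenizer.decode(&ids, true, true);
    println!("{tokens:?} -> {ids:?} -> {decoded}");
    Ok(())
}
```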