Support for HF Tokenizers (#408)

* tokenizers output type conversion

* WIP hf Tokenizers support (2)

* finalize interface methods for hf tokenizers

* Addition of GPT2 example with hf tokenizers

* Made hf-tokenizers optional, added doc for HFTokenizer

* Addition of tests for hf tokenizers, addition to CI

* Updated changelog, extended documentation

* Fix Clippy warnings
guillaume-be 2023-08-13 11:09:02 +01:00 committed by GitHub
parent af3839e91c
commit fd1e66b1c7
16 changed files with 1096 additions and 64 deletions


@ -140,8 +140,8 @@ jobs:
--test nllb
--features download-libtorch
test-onnx:
name: Integration tests (ONNX models)
test-opt-features:
name: Integration tests (Optional features)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
@ -155,7 +155,9 @@ jobs:
command: test
args: --package rust-bert
--features onnx
--features hf-tokenizers
--test onnx
--test hf_tokenizers
--features download-libtorch
convert-model:


@ -3,7 +3,8 @@ All notable changes to this project will be documented in this file. The format
## [Unreleased]
## Added
- Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines
- Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines.
- Support for [Tokenizers](https://github.com/huggingface/tokenizers) in pipelines, allowing `tokenizer.json` and `special_token_map.json` tokenizer files to be loaded.
## Fixed
- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences).


@ -69,6 +69,7 @@ download-libtorch = ["tch/download-libtorch"]
onnx = ["ort", "ndarray"]
rustls-tls = ["cached-path/rustls-tls"]
default-tls = ["cached-path/default-tls"]
hf-tokenizers = ["tokenizers"]
[package.metadata.docs.rs]
features = ["doc-only"]
@ -89,6 +90,7 @@ dirs = { version = "4", optional = true }
lazy_static = { version = "1", optional = true }
ort = {version="~1.14.8", optional = true, default-features = false, features = ["half"]}
ndarray = {version="0.15", optional = true}
tokenizers = {version="0.13.3", optional=true, default-features = false, features = ["onig"]}
[dev-dependencies]
anyhow = "1"


@ -0,0 +1,61 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::common::{ModelType, TokenizerOption};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use std::fs::File;
use std::io::Write;
use tempfile::TempDir;
fn main() -> anyhow::Result<()> {
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
max_length: Some(30),
do_sample: false,
num_beams: 1,
temperature: 1.0,
num_return_sequences: 1,
..Default::default()
};
// Create tokenizer
let tmp_dir = TempDir::new()?;
let special_token_map_path = tmp_dir.path().join("special_token_map.json");
let mut tmp_file = File::create(&special_token_map_path)?;
writeln!(
tmp_file,
r#"{{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}}"#
)?;
let tokenizer_path = RemoteResource::from_pretrained((
"gpt2/tokenizer",
"https://huggingface.co/gpt2/resolve/main/tokenizer.json",
))
.get_local_path()?;
let tokenizer =
TokenizerOption::from_hf_tokenizer_file(tokenizer_path, special_token_map_path)?;
let model = TextGenerationModel::new_with_tokenizer(generate_config, tokenizer)?;
let input_context = "The dog";
// let second_input_context = "The cat was";
let output = model.generate(&[input_context], None);
for sentence in output {
println!("{sentence:?}");
}
Ok(())
}


@ -114,7 +114,9 @@ fn make_global_fixed_block_ids(
attention_mask: &Tensor,
global_block_size: i64,
) -> (Tensor, Tensor) {
let &[batch_size, seq_length, ..] = attention_mask.size().as_slice() else {unreachable!()};
let &[batch_size, seq_length, ..] = attention_mask.size().as_slice() else {
unreachable!()
};
let handle_orphan_tokens = |block_ids: Tensor| -> Tensor {
let block_ends = Tensor::arange(seq_length, (Kind::Int64, block_ids.device()))


@ -56,13 +56,18 @@ use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::fmt::Debug;
use std::path::{Path, PathBuf};
use tch::{Device, Kind, Tensor};
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXModelConfig;
#[cfg(feature = "hf-tokenizers")]
use crate::pipelines::hf_tokenizers::HFTokenizer;
#[derive(Debug, Default)]
/// Container for ONNX model resources, containing 3 optional resources (Encoder, Decoder and Decoder with past)
pub struct ONNXModelResources {
@ -288,6 +293,9 @@ pub enum TokenizerOption {
FNet(FNetTokenizer),
/// Bart Tokenizer
Bart(RobertaTokenizer),
/// HF Tokenizer
#[cfg(feature = "hf-tokenizers")]
HFTokenizer(HFTokenizer),
}
impl ConfigOption {
@ -913,28 +921,13 @@ impl TokenizerOption {
Ok(tokenizer)
}
/// Returns the model type
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Deberta(_) => ModelType::Deberta,
Self::DebertaV2(_) => ModelType::DebertaV2,
Self::Roberta(_) => ModelType::Roberta,
Self::Bart(_) => ModelType::Bart,
Self::XLMRoberta(_) => ModelType::XLMRoberta,
Self::Marian(_) => ModelType::Marian,
Self::T5(_) => ModelType::T5,
Self::Albert(_) => ModelType::Albert,
Self::XLNet(_) => ModelType::XLNet,
Self::GPT2(_) => ModelType::GPT2,
Self::OpenAiGpt(_) => ModelType::OpenAiGpt,
Self::Reformer(_) => ModelType::Reformer,
Self::ProphetNet(_) => ModelType::ProphetNet,
Self::Pegasus(_) => ModelType::Pegasus,
Self::MBart50(_) => ModelType::MBart,
Self::M2M100(_) | Self::NLLB(_) => ModelType::M2M100,
Self::FNet(_) => ModelType::FNet,
}
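/// Creates a `TokenizerOption` from a `tokenizer.json` and a `special_token_map.json` file
/// (see the `hf_tokenizers` module documentation for the expected file contents)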
#[cfg(feature = "hf-tokenizers")]
pub fn from_hf_tokenizer_file<P: AsRef<Path>, S: AsRef<Path>>(
tokenizer_file: P,
special_token_map: S,
) -> Result<Self, RustBertError> {
let hf_tokenizer = HFTokenizer::from_file(tokenizer_file, special_token_map)?;
Ok(TokenizerOption::HFTokenizer(hf_tokenizer))
}
/// Interface method
@ -946,7 +939,7 @@ impl TokenizerOption {
stride: usize,
) -> Vec<TokenizedInput>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
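// `Send` is required in addition to `Sync` because the hf-tokenizers backend encodes
// batches via `Tokenizer::encode_batch`, which takes `Send` inputs for parallel processing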
{
match *self {
Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_list(
@ -1082,6 +1075,8 @@ impl TokenizerOption {
truncation_strategy,
stride,
),
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer.encode_list(text_list).unwrap(),
}
}
@ -1227,6 +1222,8 @@ impl TokenizerOption {
truncation_strategy,
stride,
),
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer.encode_pair_list(text_pair_list).unwrap(),
}
}
@ -1297,6 +1294,8 @@ impl TokenizerOption {
Self::FNet(ref tokenizer) => {
tokenizer.encode(text_1, text_2, max_len, truncation_strategy, stride)
}
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer.encode_pair(text_1, text_2).unwrap(),
}
}
@ -1322,6 +1321,8 @@ impl TokenizerOption {
Self::M2M100(ref tokenizer) => tokenizer.tokenize(text),
Self::NLLB(ref tokenizer) => tokenizer.tokenize(text),
Self::FNet(ref tokenizer) => tokenizer.tokenize(text),
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize(text),
}
}
@ -1347,13 +1348,15 @@ impl TokenizerOption {
Self::M2M100(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::NLLB(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
Self::FNet(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize_with_offsets(text),
}
}
/// Interface method to tokenization
pub fn tokenize_list<S>(&self, text: &[S]) -> Vec<Vec<String>>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
match *self {
Self::Bert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
@ -1383,6 +1386,8 @@ impl TokenizerOption {
Self::M2M100(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::NLLB(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::FNet(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer.tokenize_list(text),
}
}
@ -1451,6 +1456,8 @@ impl TokenizerOption {
Self::FNet(ref tokenizer) => {
tokenizer.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces)
}
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer.decode(token_ids, skip_special_tokens),
}
}
@ -1537,6 +1544,13 @@ impl TokenizerOption {
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => {
return tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
)
}
};
TokenizedInput {
token_ids: token_ids_with_special_tokens.token_ids,
@ -1736,6 +1750,8 @@ impl TokenizerOption {
Self::M2M100(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::NLLB(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::FNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
}
}
@ -1818,6 +1834,10 @@ impl TokenizerOption {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
vocab.token_to_id(vocab.get_unknown_value())
}
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => {
tokenizer.token_to_id(&tokenizer.special_token_map.unk_token)
}
}
}
@ -1888,6 +1908,12 @@ impl TokenizerOption {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_pad_value()))
}
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer
.special_token_map
.pad_token
.as_ref()
.map(|token| tokenizer.token_to_id(token)),
Self::Reformer(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
@ -1949,6 +1975,12 @@ impl TokenizerOption {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_sep_value()))
}
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer
.special_token_map
.sep_token
.as_ref()
.map(|token| tokenizer.token_to_id(token)),
Self::Marian(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,
@ -2009,6 +2041,12 @@ impl TokenizerOption {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_mask_value()))
}
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer
.special_token_map
.mask_token
.as_ref()
.map(|token| tokenizer.token_to_id(token)),
Self::Marian(_) => None,
Self::M2M100(_) => None,
Self::NLLB(_) => None,
@ -2058,6 +2096,8 @@ impl TokenizerOption {
Self::Pegasus(ref tokenizer) => {
Some(MultiThreadedTokenizer::vocab(tokenizer).get_mask_value())
}
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer.special_token_map.mask_token.as_deref(),
Self::M2M100(_) => None,
Self::NLLB(_) => None,
Self::Marian(_) => None,
@ -2111,6 +2151,12 @@ impl TokenizerOption {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_bos_value()))
}
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer
.special_token_map
.bos_token
.as_ref()
.map(|token| tokenizer.token_to_id(token)),
Self::MBart50(_) => Some(0),
Self::FNet(_) => None,
Self::Bert(_) => None,
@ -2186,6 +2232,12 @@ impl TokenizerOption {
let vocab = MultiThreadedTokenizer::vocab(tokenizer);
Some(vocab.token_to_id(vocab.get_eos_value()))
}
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref tokenizer) => tokenizer
.special_token_map
.eos_token
.as_ref()
.map(|token| tokenizer.token_to_id(token)),
Self::FNet(_) => None,
Self::Bert(_) => None,
Self::ProphetNet(_) => None,
@ -2264,6 +2316,8 @@ impl TokenizerOption {
Self::M2M100(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
Self::NLLB(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
Self::FNet(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref mut tokenizer) => tokenizer.add_extra_ids(num_extra_ids),
}
}
@ -2289,6 +2343,8 @@ impl TokenizerOption {
Self::M2M100(ref mut tokenizer) => tokenizer.add_tokens(tokens),
Self::NLLB(ref mut tokenizer) => tokenizer.add_tokens(tokens),
Self::FNet(ref mut tokenizer) => tokenizer.add_tokens(tokens),
#[cfg(feature = "hf-tokenizers")]
Self::HFTokenizer(ref mut tokenizer) => tokenizer.add_tokens(tokens),
}
}
}


@ -381,7 +381,7 @@ pub(crate) mod private_generation_utils {
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let token_ids = if self.is_encoder_decoder() {
let tokens = self._get_tokenizer().encode_list(
@ -1774,7 +1774,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
generate_options: Option<GenerateOptions>,
) -> Vec<GeneratedTextOutput>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let indices_outputs = self.generate_indices(prompt_texts, generate_options);
let mut output = Vec::with_capacity(indices_outputs.len());
@ -1868,7 +1868,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
generate_options: Option<GenerateOptions>,
) -> Vec<GeneratedIndicesOutput>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let eos_token_ids = self.get_eos_ids();


@ -0,0 +1,764 @@
/// # Support for [tokenizers](https://github.com/huggingface/tokenizers)
///
/// This module implements interface methods to allow loading tokenizers trained and implemented with
/// the [Tokenizers](https://github.com/huggingface/tokenizers) crate. While the functionality of these tokenizers
/// is expected to be identical to the default [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers) used
/// in this crate, the implementation and input file format differs.
///
/// Because part of the special token handling logic is implemented at the Python level on top of the Rust bindings,
/// the proposed implementation requires two files to be provided:
/// - `tokenizer.json` containing the tokenizer model, pre- and post-processing options and vocabulary
/// - `special_token_map.json` containing a mapping of the special tokens used by the model (e.g. BOS and CLS values)
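///
/// As an illustration, a minimal `special_token_map.json` for a GPT-2 style tokenizer could
/// contain the following (the token values depend on the model used):
/// ```json
/// {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
/// ```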
use crate::RustBertError;
use rust_tokenizers::{
Mask, Offset, OffsetSize, TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets,
};
use serde::{de, Deserialize, Deserializer};
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use tokenizers::tokenizer::Tokenizer as HFBaseTokenizer;
use tokenizers::{AddedToken, EncodeInput, Encoding, InputSequence};
impl From<tokenizers::tokenizer::Error> for RustBertError {
fn from(error: tokenizers::tokenizer::Error) -> Self {
RustBertError::TokenizerError(error.to_string())
}
}
/// Container for a special token map to be deserialized from a `special_token_map.json`
#[derive(Debug, Default, Clone, Deserialize)]
pub struct SpecialTokenMap {
/// Unknown token (must be provided for all tokenizers)
pub unk_token: String,
/// Optional padding token
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub pad_token: Option<String>,
/// Optional bos token
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub bos_token: Option<String>,
/// Optional sep token
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub sep_token: Option<String>,
/// Optional cls token
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub cls_token: Option<String>,
/// Optional eos token
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub eos_token: Option<String>,
/// Optional mask token
#[serde(default)]
#[serde(deserialize_with = "string_or_added_token_struct")]
pub mask_token: Option<String>,
/// Optional additional special tokens
pub additional_special_tokens: Option<HashSet<String>>,
}
/// Deserialization utility function for `special_token_map.json` to read nested special tokens structure
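/// Both the plain string form and the nested `AddedToken`-style form are accepted,
/// e.g. `"pad_token": "[PAD]"` as well as `"pad_token": {"content": "[PAD]"}`
/// (keys other than `content` are ignored)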
fn string_or_added_token_struct<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: Deserializer<'de>,
{
struct StringOrStruct;
impl<'de> de::Visitor<'de> for StringOrStruct {
type Value = Option<String>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("string or map")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Some(value.to_string()))
}
fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
where
M: de::MapAccess<'de>,
{
let mut value = None;
while let Some(key) = map.next_key::<String>()? {
if key == "content" {
value = Some(map.next_value::<String>()?);
} else {
// Ignore the value for any key other than `content`, whatever its type
map.next_value::<de::IgnoredAny>()?;
}
}
Ok(value)
}
}
deserializer.deserialize_any(StringOrStruct)
}
/// Wrapper for a tokenizer from the Tokenizers library
pub struct HFTokenizer {
/// Base tokenizer object
tokenizer: HFBaseTokenizer,
/// Special token map
pub(crate) special_token_map: SpecialTokenMap,
}
impl HFTokenizer {
/// Create a new tokenizer from a file.
///
/// # Arguments
/// - `tokenizer_file` path to location containing the tokenizer model, pre- and post-processing options and vocabulary
/// - `special_token_map` path to location containing a mapping of the special tokens used by the model (e.g. BOS and CLS values)
///
/// # Returns
/// - Wrapper around a tokenizer that can be loaded in a `TokenizerOption` in this crate
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// use std::path::PathBuf;
/// let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// # Ok(())
/// # }
/// ```
pub fn from_file<P: AsRef<Path>, S: AsRef<Path>>(
tokenizer_file: P,
special_token_map: S,
) -> Result<Self, RustBertError> {
let tokenizer = HFBaseTokenizer::from_file(tokenizer_file)?;
let f = File::open(&special_token_map).map_err(|e| {
RustBertError::IOError(format!(
"{} special token map file not found :{}",
special_token_map.as_ref().display(),
e
))
})?;
let br = BufReader::new(f);
let special_token_map = serde_json::from_reader(br).map_err(|e| {
RustBertError::IOError(format!("Invalid special token mapping file {e}"))
})?;
Ok(Self {
tokenizer,
special_token_map,
})
}
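/// Converts a `tokenizers` `Encoding` into the `TokenizedInput` structure used by the pipelines in this crate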
fn encoding_to_tokenized_input(encoding: Encoding) -> TokenizedInput {
let token_ids = encoding
.get_ids()
.iter()
.map(|token_id| *token_id as i64)
.collect();
let segment_ids = encoding
.get_type_ids()
.iter()
.map(|segment_id| *segment_id as i8)
.collect();
let special_tokens_mask = encoding
.get_special_tokens_mask()
.iter()
.map(|segment_id| *segment_id as i8)
.collect();
let overflowing_tokens: Vec<i64> = encoding
.get_overflowing()
.iter()
.flat_map(|encoding| encoding.get_ids())
.map(|token_id| *token_id as i64)
.collect();
let num_truncated_tokens = overflowing_tokens.len();
let token_offsets = encoding
.get_offsets()
.iter()
.map(|offset| {
Some(Offset {
begin: offset.0 as OffsetSize,
end: offset.1 as OffsetSize,
})
})
.collect();
let reference_offsets = encoding
.get_offsets()
.iter()
.map(|offset| (offset.0 as OffsetSize..offset.1 as OffsetSize).collect())
.collect();
let mask = encoding
.get_special_tokens_mask()
.iter()
.map(|segment_id| {
if *segment_id == 0 {
Mask::None
} else {
Mask::Special
}
})
.collect();
TokenizedInput {
token_ids,
segment_ids,
special_tokens_mask,
overflowing_tokens,
num_truncated_tokens,
token_offsets,
reference_offsets,
mask,
}
}
/// Encode a list of texts
///
/// # Arguments
/// - `text_list` slice of string-like inputs to encode
///
/// # Returns
/// - `Vec<TokenizedInput>` containing the tokenized and encoded texts
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// let texts = &["first text to encode", "second text to encode"];
/// let output = tokenizer.encode_list(texts);
/// # Ok(())
/// # }
/// ```
pub fn encode_list<S>(&self, text_list: &[S]) -> Result<Vec<TokenizedInput>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
let encoding_inputs = text_list.iter().map(|text| text.as_ref()).collect();
let encodings = self.tokenizer.encode_batch(encoding_inputs, true)?;
let mut tokenized_inputs: Vec<TokenizedInput> = Vec::with_capacity(encodings.len());
for encoding in encodings {
tokenized_inputs.push(Self::encoding_to_tokenized_input(encoding));
}
Ok(tokenized_inputs)
}
/// Encode a list of text pairs
///
/// This is used for applications where the model takes two input sequences as input (e.g. natural language inference).
///
/// # Arguments
/// - `text_pair_list` slice of tuples of string-like inputs to encode
///
/// # Returns
/// - `Vec<TokenizedInput>` containing the tokenized and encoded texts
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// let texts = &[
/// (
/// "first text of first pair to encode",
/// "second text of first pair to encode",
/// ),
/// (
/// "first text of second pair to encode",
/// "second text of second pair to encode",
/// ),
/// ];
/// let output = tokenizer.encode_pair_list(texts);
/// # Ok(())
/// # }
/// ```
pub fn encode_pair_list(
&self,
text_pair_list: &[(&str, &str)],
) -> Result<Vec<TokenizedInput>, RustBertError> {
let encoding_inputs: Vec<EncodeInput> = text_pair_list
.iter()
.map(|(text_1, text_2)| {
EncodeInput::Dual(
InputSequence::Raw(Cow::Borrowed(text_1)),
InputSequence::Raw(Cow::Borrowed(text_2)),
)
})
.collect();
let encodings = self.tokenizer.encode_batch(encoding_inputs, true)?;
let mut tokenized_inputs: Vec<TokenizedInput> = Vec::with_capacity(encodings.len());
for encoding in encodings {
tokenized_inputs.push(Self::encoding_to_tokenized_input(encoding));
}
Ok(tokenized_inputs)
}
/// Encode a single text pair
///
/// This is used for applications where the model takes two input sequences as input (e.g. natural language inference).
/// This generic method handles both the case where a second input is provided and the case where it is not
/// (falling back to single-sequence encoding)
///
/// # Arguments
/// - `text_1` string slice for the first text
/// - `text_2` Optional string slice for the second text
///
/// # Returns
/// - `TokenizedInput` containing the tokenized and encoded texts
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// let text_1 = "first text to encode";
/// let output_1 = tokenizer.encode_pair(text_1, None);
/// let text_2 = "second text to encode";
/// let output_2 = tokenizer.encode_pair(text_1, Some(text_2));
/// # Ok(())
/// # }
/// ```
pub fn encode_pair(
&self,
text_1: &str,
text_2: Option<&str>,
) -> Result<TokenizedInput, RustBertError> {
let encoding_input = if let Some(text_2) = text_2 {
EncodeInput::Dual(
InputSequence::Raw(Cow::Borrowed(text_1)),
InputSequence::Raw(Cow::Borrowed(text_2)),
)
} else {
EncodeInput::Single(InputSequence::Raw(Cow::Borrowed(text_1)))
};
let encoding = self.tokenizer.encode(encoding_input, true)?;
Ok(Self::encoding_to_tokenized_input(encoding))
}
/// Tokenize a text
///
/// # Arguments
/// - `text` string slice to tokenize
///
/// # Returns
/// - `Vec<String>` tokenized text
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// let text = "first text to encode";
/// let output = tokenizer.tokenize(text);
/// # Ok(())
/// # }
/// ```
pub fn tokenize(&self, text: &str) -> Vec<String> {
self.tokenizer
.encode(text, false)
.unwrap()
.get_tokens()
.to_vec()
}
/// Tokenize a list of texts
///
/// # Arguments
/// - `texts` slice of string-like references to tokenize
///
/// # Returns
/// - `Vec<Vec<String>>` tokenized texts
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// let texts = &["first text to encode", "second text to encode"];
/// let output = tokenizer.tokenize_list(texts);
/// # Ok(())
/// # }
/// ```
pub fn tokenize_list<S>(&self, texts: &[S]) -> Vec<Vec<String>>
where
S: AsRef<str> + Send + Sync,
{
texts
.iter()
.map(|text| self.tokenize(text.as_ref()))
.collect()
}
/// Tokenize a text with offsets information
///
/// # Arguments
/// - `text` string slice to tokenize with offsets
///
/// # Returns
/// - `TokensWithOffsets` containing the tokens together with their offsets and masks
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// let text = "first text to encode";
/// let output = tokenizer.tokenize_with_offsets(text);
/// # Ok(())
/// # }
/// ```
pub fn tokenize_with_offsets(&self, text: &str) -> TokensWithOffsets {
let encoding = self.tokenizer.encode(text, false).unwrap();
let tokens = encoding.get_tokens().to_vec();
let offsets = encoding
.get_offsets()
.iter()
.map(|offset| {
Some(Offset {
begin: offset.0 as OffsetSize,
end: offset.1 as OffsetSize,
})
})
.collect();
let reference_offsets = encoding
.get_offsets()
.iter()
.map(|offset| (offset.0 as OffsetSize..offset.1 as OffsetSize).collect())
.collect();
let masks = encoding
.get_special_tokens_mask()
.iter()
.map(|segment_id| {
if *segment_id == 0 {
Mask::None
} else {
Mask::Special
}
})
.collect();
TokensWithOffsets {
tokens,
offsets,
reference_offsets,
masks,
}
}
/// Decode a sequence of token ids into a text
///
/// # Arguments
/// - `token_ids` slice of token ids
/// - `skip_special_tokens` flag indicating if special tokens should be skipped during decoding
///
/// # Returns
/// - `String` decoded text
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// let token_ids = &[0, 2, 5, 9, 4, 2, 1];
/// let skip_special_token_ids = true;
/// let output = tokenizer.decode(token_ids, skip_special_token_ids);
/// # Ok(())
/// # }
/// ```
pub fn decode(&self, token_ids: &[i64], skip_special_tokens: bool) -> String {
self.tokenizer
.decode(
token_ids.iter().map(|token_id| *token_id as u32).collect(),
skip_special_tokens,
)
.unwrap()
}
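/// Converts a `TokenIdsWithOffsets` back into a `tokenizers` `Encoding` so that the
/// tokenizer's post-processor can be applied to pre-tokenized inputs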
fn token_ids_with_offsets_to_encoding(
&self,
token_ids_with_offsets: TokenIdsWithOffsets,
) -> Encoding {
let ids: Vec<u32> = token_ids_with_offsets
.ids
.iter()
.map(|token_id| *token_id as u32)
.collect();
let type_ids = token_ids_with_offsets
.ids
.iter()
.map(|id| *id as u32)
.collect();
let tokens = ids
.iter()
.map(|token_id| {
self.tokenizer.id_to_token(*token_id).unwrap_or_else(|| {
self.tokenizer.decode(vec![*token_id], false).unwrap()
})
})
.collect();
let words = vec![None::<u32>; ids.len()];
let offsets = token_ids_with_offsets
.offsets
.iter()
.map(|offset| {
offset
.map(|offset| (offset.begin as usize, offset.end as usize))
.unwrap_or((0, 0))
})
.collect();
let special_tokens_mask = token_ids_with_offsets
.masks
.iter()
.map(|segment_id| match segment_id {
Mask::Special => 1,
_ => 0,
})
.collect();
let overflowing: Vec<Encoding> = vec![];
let attention_mask = vec![1; ids.len()];
let sequence_ranges = HashMap::new();
Encoding::new(
ids,
type_ids,
tokens,
words,
offsets,
special_tokens_mask,
attention_mask,
overflowing,
sequence_ranges,
)
}
/// Post-process a sequence or sequence pair
///
/// Adds the special tokens for a single sequence or a pair of sequences and applies the tokenizer post-processing
///
/// # Arguments
/// - `token_ids_with_offsets_1` first sequence's `TokenIdsWithOffsets`
/// - `token_ids_with_offsets_2` optional second sequence's `TokenIdsWithOffsets`
///
/// # Returns
/// - `TokenizedInput` post-processed encoding for the inputs provided.
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// use rust_tokenizers::{Offset, TokenIdsWithOffsets};
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// let token_ids_with_offsets_1 = TokenIdsWithOffsets {
/// ids: vec![0, 1, 2],
/// offsets: vec![
/// Some(Offset { begin: 0, end: 1 }),
/// Some(Offset { begin: 1, end: 2 }),
/// Some(Offset { begin: 2, end: 3 }),
/// ],
/// reference_offsets: vec![vec![0], vec![1], vec![2]],
/// masks: vec![],
/// };
/// let token_ids_with_offsets_2 = TokenIdsWithOffsets {
/// ids: vec![8, 9, 10],
/// offsets: vec![
/// Some(Offset { begin: 3, end: 4 }),
/// Some(Offset { begin: 4, end: 5 }),
/// Some(Offset { begin: 5, end: 6 }),
/// ],
/// reference_offsets: vec![vec![3], vec![4], vec![5]],
/// masks: vec![],
/// };
/// let output = tokenizer
/// .build_input_with_special_tokens(token_ids_with_offsets_1, Some(token_ids_with_offsets_2));
/// # Ok(())
/// # }
/// ```
pub fn build_input_with_special_tokens(
&self,
token_ids_with_offsets_1: TokenIdsWithOffsets,
token_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
) -> TokenizedInput {
let encoding_1 = self.token_ids_with_offsets_to_encoding(token_ids_with_offsets_1);
let encoding_2 = token_ids_with_offsets_2
.map(|encoding| self.token_ids_with_offsets_to_encoding(encoding));
let encoding_output = self
.tokenizer
.post_process(encoding_1, encoding_2, true)
.unwrap();
Self::encoding_to_tokenized_input(encoding_output)
}
/// Converts a single token to a token id
///
/// Returns the unknown token id if the item is not present in the tokenizer vocabulary.
///
/// # Arguments
/// - `token` string slice to convert
///
/// # Returns
/// - `i64` token id (or unknown token id if not found in the vocabulary)
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// let token = "Hello";
/// let output = tokenizer.token_to_id(token);
/// # Ok(())
/// # }
/// ```
pub fn token_to_id(&self, token: &str) -> i64 {
self.tokenizer.token_to_id(token).unwrap_or_else(|| {
self.tokenizer
.token_to_id(self.special_token_map.unk_token.as_str())
.unwrap()
}) as i64
}
/// Converts a slice of tokens to token ids
///
/// Returns the unknown token id if the item is not present in the tokenizer vocabulary.
///
/// # Arguments
/// - `tokens` slice of string slices to convert
///
/// # Returns
/// - `Vec<i64>` token ids (with unknown token id at position of items not found in the vocabulary)
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// let tokens = &["Hello", "world", "!"];
/// let output = tokenizer.convert_tokens_to_ids(tokens);
/// # Ok(())
/// # }
/// ```
pub fn convert_tokens_to_ids<S>(&self, tokens: &[S]) -> Vec<i64>
where
S: AsRef<str>,
{
tokens
.iter()
.map(|token| self.token_to_id(token.as_ref()))
.collect()
}
/// Add tokens to the tokenizer vocabulary
///
/// These tokens are not part of the trained tokenization model and are simply added to the vocabulary
///
/// # Arguments
/// - `tokens` tokens to add to the vocabulary
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let mut tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// tokenizer.add_tokens(&["<CLS>", "<SEP>"]);
/// # Ok(())
/// # }
/// ```
pub fn add_tokens(&mut self, tokens: &[&str]) {
let added_tokens = tokens
.iter()
.map(|token| AddedToken {
content: token.to_string(),
single_word: false,
lstrip: false,
rstrip: false,
normalized: false,
special: false,
})
.collect::<Vec<AddedToken>>();
self.tokenizer.add_tokens(&added_tokens);
}
/// Add extra token ids to the tokenizer vocabulary
///
/// These tokens are automatically formatted as "<extra_id_{extra_id}>"
///
/// # Arguments
/// - `num_extra_ids` number of tokens to add
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::hf_tokenizers::HFTokenizer;
/// # use std::path::PathBuf;
/// # let tokenizer_file_path = PathBuf::from("path/to/tokenizer.json");
/// # let special_token_map_path = PathBuf::from("path/to/special_token_map.json");
/// let mut tokenizer = HFTokenizer::from_file(tokenizer_file_path, special_token_map_path)?;
/// tokenizer.add_extra_ids(42);
/// # Ok(())
/// # }
/// ```
pub fn add_extra_ids(&mut self, num_extra_ids: i64) {
let mut added_tokens: Vec<AddedToken> = Vec::with_capacity(num_extra_ids as usize);
for extra_id in 0..num_extra_ids {
added_tokens.push(AddedToken {
content: format!("<extra_id_{extra_id}>"),
single_word: false,
lstrip: false,
rstrip: false,
normalized: false,
special: false,
});
}
self.tokenizer.add_tokens(&added_tokens);
}
}


@ -220,7 +220,7 @@ impl<'a> KeywordExtractionModel<'a> {
/// ```
pub fn predict<S>(&self, inputs: &[S]) -> Result<Vec<Vec<Keyword>>, RustBertError>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let words = self.tokenizer.tokenize_list(inputs, self.ngram_range);
let (flat_word_list, document_boundaries) =


@ -473,6 +473,56 @@
//! ]
//! # ;
//! ```
//!
//! # [Tokenizers](https://github.com/huggingface/tokenizers) support
//!
//! The pipelines support both the default [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers) and
//! Hugging Face's [Tokenizers](https://github.com/huggingface/tokenizers) library. In order to use the latter,
//! the tokenizer needs to be created manually and passed as an argument to the pipeline's `new_with_tokenizer` method.
//!
//! Note that a `special_token_map.json` file is required to create a `TokenizerOption` from an `HFTokenizer`. This file is sometimes not provided
//! (the Python Transformers library stores the special token map information as part of the Python tokenizer object wrapping the Rust-based
//! tokenizer). If that is the case, a temporary file with the special token map information can be created as illustrated below:
//! ```no_run
//! fn main() -> anyhow::Result<()> {
//! use std::fs::File;
//! use std::io::Write;
//! use tempfile::TempDir;
//! use rust_bert::pipelines::common::{ModelType, TokenizerOption};
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
//! use rust_bert::resources::{RemoteResource, ResourceProvider};
//!
//! let generate_config = TextGenerationConfig {
//! model_type: ModelType::GPT2,
//! ..Default::default()
//! };
//!
//! // Create tokenizer
//! let tmp_dir = TempDir::new()?;
//! let special_token_map_path = tmp_dir.path().join("special_token_map.json");
//! let mut tmp_file = File::create(&special_token_map_path)?;
//! writeln!(
//! tmp_file,
//! r#"{{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}}"#
//! )?;
//! let tokenizer_path = RemoteResource::from_pretrained((
//! "gpt2/tokenizer",
//! "https://huggingface.co/gpt2/resolve/main/tokenizer.json",
//! )).get_local_path()?;
//! let tokenizer =
//! TokenizerOption::from_hf_tokenizer_file(tokenizer_path, special_token_map_path)?;
//!
//! // Create model
//! let model = TextGenerationModel::new_with_tokenizer(generate_config, tokenizer)?;
//!
//! let input_context = "The dog";
//! let output = model.generate(&[input_context], None);
//! for sentence in output {
//! println!("{sentence:?}");
//! }
//! Ok(())
//! }
//! ```
pub mod common;
pub mod conversation;
@ -493,3 +543,6 @@ pub mod zero_shot_classification;
#[cfg(feature = "onnx")]
pub mod onnx;
#[cfg(feature = "hf-tokenizers")]
pub mod hf_tokenizers;


@ -314,7 +314,7 @@ impl SentenceEmbeddingsModel {
/// Tokenizes the inputs
pub fn tokenize<S>(&self, inputs: &[S]) -> SentenceEmbeddingsTokenizerOutput
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let tokenized_input = self.tokenizer.encode_list(
inputs,
@ -368,7 +368,7 @@ impl SentenceEmbeddingsModel {
inputs: &[S],
) -> Result<SentenceEmbeddingsModelOutput, RustBertError>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let SentenceEmbeddingsTokenizerOutput {
tokens_ids,
@ -413,7 +413,7 @@ impl SentenceEmbeddingsModel {
/// Computes sentence embeddings.
pub fn encode<S>(&self, inputs: &[S]) -> Result<Vec<Embedding>, RustBertError>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let SentenceEmbeddingsModelOutput { embeddings, .. } = self.encode_as_tensor(inputs)?;
Ok(Vec::try_from(embeddings)?)
@ -457,7 +457,7 @@ impl SentenceEmbeddingsModel {
inputs: &[S],
) -> Result<(Vec<Embedding>, Vec<AttentionOutput>), RustBertError>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let SentenceEmbeddingsModelOutput {
embeddings,


@ -336,7 +336,7 @@ impl SummarizationOption {
/// Interface method to generate() of the particular models.
pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Vec<String>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
match *self {
Self::Bart(ref model) => model
@ -504,7 +504,7 @@ impl SummarizationModel {
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
pub fn summarize<S>(&self, texts: &[S]) -> Vec<String>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
match &self.prefix {
None => self.model.generate(Some(texts)),


@ -333,7 +333,7 @@ impl TextGenerationOption {
max_length: Option<i64>,
) -> Vec<Vec<i64>>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let generate_options = Some(GenerateOptions {
min_length,
@ -597,7 +597,7 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
/// ```
pub fn generate<'a, S>(&self, texts: &[S], prefix: impl Into<Option<&'a str>>) -> Vec<String>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let (prefix, prefix_length) = match (prefix.into(), &self.prefix) {
(Some(query_prefix), _) => (


@ -128,7 +128,6 @@ use crate::resources::ResourceProvider;
use crate::roberta::RobertaForTokenClassification;
use crate::xlnet::XLNetForTokenClassification;
use ordered_float::OrderedFloat;
use rust_tokenizers::tokenizer::Tokenizer;
use rust_tokenizers::{
ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenIdsWithOffsets, TokenTrait,
TokenizedInput,
@ -1103,27 +1102,7 @@ impl TokenClassificationModel {
let offsets = &sentence_tokens.offsets[position_idx as usize];
let text = match offsets {
None => match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => {
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::Roberta(ref tokenizer) => {
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::XLMRoberta(ref tokenizer) => {
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::Albert(ref tokenizer) => {
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::XLNet(ref tokenizer) => {
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
_ => panic!(
"Token classification not implemented for {:?}!",
self.tokenizer.model_type()
),
},
None => self.tokenizer.decode(&[token_id], false, false),
Some(offsets) => {
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
let end_char = min(end_char, original_sentence_chars.len());


@ -1217,7 +1217,7 @@ impl TranslationOption {
forced_bos_token_id: Option<i64>,
) -> Vec<String>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
match *self {
Self::Marian(ref model) => model
@ -1470,7 +1470,7 @@ impl TranslationModel {
target_language: impl Into<Option<Language>>,
) -> Result<Vec<String>, RustBertError>
where
S: AsRef<str> + Sync,
S: AsRef<str> + Send + Sync,
{
let (prefix, forced_bos_token_id) =
self.model.get_tokenizer().get_prefix_and_forced_bos_id(

tests/hf_tokenizers.rs (new file)

@ -0,0 +1,112 @@
#[cfg(feature = "hf-tokenizers")]
mod tests {
use rust_bert::gpt2::{Gpt2ConfigResources, Gpt2ModelResources};
use rust_bert::pipelines::common::{ModelResource, ModelType, TokenizerOption};
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{LocalResource, RemoteResource, ResourceProvider};
use std::fs::File;
use std::io::Write;
use tch::Device;
use tempfile::TempDir;
#[test]
fn gpt2_generation() -> anyhow::Result<()> {
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let dummy_vocab_resource = Box::new(LocalResource {
local_path: Default::default(),
});
let tokenizer_resource = Box::new(RemoteResource::from_pretrained((
"gpt2/tokenizer",
"https://huggingface.co/gpt2/resolve/main/tokenizer.json",
)));
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource: dummy_vocab_resource,
merges_resource: None,
max_length: Some(20),
do_sample: false,
num_beams: 5,
temperature: 1.2,
device: Device::Cpu,
num_return_sequences: 3,
..Default::default()
};
// Create tokenizer
let tmp_dir = TempDir::new()?;
let special_token_map_path = tmp_dir.path().join("special_token_map.json");
let mut tmp_file = File::create(&special_token_map_path)?;
writeln!(
tmp_file,
r#"{{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}}"#
)?;
let tokenizer_path = tokenizer_resource.get_local_path()?;
let tokenizer =
TokenizerOption::from_hf_tokenizer_file(tokenizer_path, special_token_map_path)?;
let model = TextGenerationModel::new_with_tokenizer(generate_config, tokenizer)?;
let input_context = "The dog";
let output = model.generate(&[input_context], None);
assert_eq!(output.len(), 3);
assert_eq!(
output[0],
"The dog was found in the backyard of a home in the 6200 block of South Main Street."
);
assert_eq!(
output[1],
"The dog was found in the backyard of a home in the 6500 block of South Main Street."
);
assert_eq!(
output[2],
"The dog was found in the backyard of a home in the 6200 block of South Main Street,"
);
Ok(())
}
#[test]
fn distilbert_question_answering() -> anyhow::Result<()> {
// Create tokenizer
let tmp_dir = TempDir::new()?;
let special_token_map_path = tmp_dir.path().join("special_token_map.json");
let mut tmp_file = File::create(&special_token_map_path)?;
writeln!(
tmp_file,
r#"{{"pad_token": "[PAD]", "sep_token": "[SEP]", "cls_token": "[CLS]", "mask_token": "[MASK]", "unk_token": "[UNK]"}}"#
)?;
let tokenizer_resource = Box::new(RemoteResource::from_pretrained((
"distilbert-base-cased-distilled-squad/tokenizer",
"https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/tokenizer.json",
)));
let tokenizer_path = tokenizer_resource.get_local_path()?;
let tokenizer =
TokenizerOption::from_hf_tokenizer_file(tokenizer_path, special_token_map_path)?;
// Set-up question answering model
let qa_model = QuestionAnsweringModel::new_with_tokenizer(Default::default(), tokenizer)?;
// Define input
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(answers.len(), 1usize);
assert_eq!(answers[0].len(), 1usize);
assert_eq!(answers[0][0].start, 13);
assert_eq!(answers[0][0].end, 22);
assert!((answers[0][0].score - 0.9978).abs() < 1e-4);
assert_eq!(answers[0][0].answer, "Amsterdam");
Ok(())
}
}