Updated README

This commit is contained in:
Guillaume B 2020-10-13 07:14:00 +02:00
parent 97ee8ee928
commit fbff61507a
7 changed files with 76 additions and 53 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.10.0"
version = "0.11.0"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"

View File

@ -10,17 +10,17 @@ This repository exposes the model base architecture, task-specific heads (see be
The following models are currently implemented:
| |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**|
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:|:----:|:----:
Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |
Sequence classification|✅ |✅ |✅| | |✅ | | |✅ | |
Token classification|✅ |✅ | ✅| | | |✅| |✅ | |
Question answering|✅ |✅ |✅| | | | | |✅ | |
Multiple choices| |✅ |✅| | | | | |✅ | |
Next token prediction| | | |✅|✅|✅| | | | |
Natural Language Generation| | | |✅|✅|✅| | | | |
Summarization | | | | | |✅| | | | |
Translation | | | | | |✅| |✅ | |✅|
| |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**|**XLNet**|
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:|:----:|:----:|:----:
Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |✅ |
Sequence classification|✅ |✅ |✅| | |✅ | | |✅ | |✅ |
Token classification|✅ |✅ | ✅| | | |✅| |✅ | |✅ |
Question answering|✅ |✅ |✅| | | | | |✅ | |✅ |
Multiple choices| |✅ |✅| | | | | |✅ | |✅ |
Next token prediction| | | |✅|✅|✅| | | | |✅ |
Natural Language Generation| | | |✅|✅|✅| | | | |✅ |
Summarization | | | | | |✅| | | | | |
Translation | | | | | |✅| |✅ | |✅| |
## Ready-to-use pipelines

View File

@ -54,7 +54,7 @@ fn main() -> anyhow::Result<()> {
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let tokenized_input = tokenizer.encode_list(&input, 1024, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(input, 1024, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -31,17 +31,17 @@
//! ```
//! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust
//!
//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**
//! :-----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:
//! Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |
//! Sequence classification|✅ |✅ |✅| | |✅| | |✅ | |
//! Token classification|✅ |✅ | ✅| | | |✅| |✅ | |
//! Question answering|✅ |✅ |✅| | | | | |✅ | |
//! Multiple choices| |✅ |✅| | | | | |✅ | |
//! Next token prediction| | | |✅|✅| | | | | |
//! Natural Language Generation| | | |✅|✅| | | | | |
//! Summarization| | | | | |✅| | | | |
//! Translation| | | | | | | |✅| |✅|
//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**|**XLNet**
//! :-----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:
//! Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |✅|
//! Sequence classification|✅ |✅ |✅| | |✅| | |✅ | |✅|
//! Token classification|✅ |✅ | ✅| | | |✅| |✅ | |✅|
//! Question answering|✅ |✅ |✅| | | | | |✅ | |✅|
//! Multiple choices| |✅ |✅| | | | | |✅ | |✅|
//! Next token prediction| | | |✅|✅| | | | | |✅|
//! Natural Language Generation| | | |✅|✅| | | | | |✅|
//! Summarization| | | | | |✅| | | | | |
//! Translation| | | | | | | |✅| |✅| |
//!
//! # Loading pre-trained models
//!

View File

@ -448,13 +448,13 @@ impl TokenizerOption {
/// Interface method to convert tokens to ids
pub fn convert_tokens_to_ids(&self, tokens: &[String]) -> Vec<i64> {
match *self {
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
}
}

View File

@ -188,7 +188,10 @@ impl NERModel {
/// # Ok(())
/// # }
/// ```
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
pub fn predict<'a, S>(&self, input: S) -> Vec<Entity>
where
S: AsRef<[&'a str]>,
{
self.token_classification_model
.predict(input, true, false)
.into_iter()

View File

@ -465,22 +465,32 @@ impl ZeroShotClassificationModel {
})
}
fn prepare_for_model(
fn prepare_for_model<'a, S, T>(
&self,
inputs: &[&str],
labels: &[&str],
inputs: S,
labels: T,
template: Option<Box<dyn Fn(&str) -> String>>,
max_len: usize,
) -> (Tensor, Tensor) {
) -> (Tensor, Tensor)
where
S: AsRef<[&'a str]>,
T: AsRef<[&'a str]>,
{
let label_sentences: Vec<String> = match template {
Some(function) => labels.iter().map(|label| function(label)).collect(),
Some(function) => labels
.as_ref()
.iter()
.map(|label| function(label))
.collect(),
None => labels
.as_ref()
.iter()
.map(|label| format!("This example is about {}.", label))
.collect(),
};
let text_pair_list = inputs
.as_ref()
.iter()
.cartesian_product(label_sentences.iter())
.map(|(&s, label)| (s, label.as_str()))
@ -577,15 +587,20 @@ impl ZeroShotClassificationModel {
/// ]
/// .to_vec();
/// ```
pub fn predict(
pub fn predict<'a, S, T>(
&self,
inputs: &[&str],
labels: &[&str],
inputs: S,
labels: T,
template: Option<Box<dyn Fn(&str) -> String>>,
max_length: usize,
) -> Vec<Label> {
let num_inputs = inputs.len();
let (input_tensor, mask) = self.prepare_for_model(inputs, labels, template, max_length);
) -> Vec<Label>
where
S: AsRef<[&'a str]>,
T: AsRef<[&'a str]>,
{
let num_inputs = inputs.as_ref().len();
let (input_tensor, mask) =
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length);
let output = no_grad(|| {
let output = self.zero_shot_classifier.forward_t(
Some(input_tensor),
@ -595,7 +610,7 @@ impl ZeroShotClassificationModel {
None,
false,
);
output.view((num_inputs as i64, labels.len() as i64, -1i64))
output.view((num_inputs as i64, labels.as_ref().len() as i64, -1i64))
});
let scores = output.softmax(1, Float).select(-1, -1);
@ -608,7 +623,7 @@ impl ZeroShotClassificationModel {
let mut output_labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.len() {
let label_string = labels[label_indices[sentence_idx] as usize].to_string();
let label_string = labels.as_ref()[label_indices[sentence_idx] as usize].to_string();
let label = Label {
text: label_string,
score: scores[sentence_idx],
@ -713,15 +728,20 @@ impl ZeroShotClassificationModel {
/// ]
/// .to_vec();
/// ```
pub fn predict_multilabel(
pub fn predict_multilabel<'a, S, T>(
&self,
inputs: &[&str],
labels: &[&str],
inputs: S,
labels: T,
template: Option<Box<dyn Fn(&str) -> String>>,
max_length: usize,
) -> Vec<Vec<Label>> {
let num_inputs = inputs.len();
let (input_tensor, mask) = self.prepare_for_model(inputs, labels, template, max_length);
) -> Vec<Vec<Label>>
where
S: AsRef<[&'a str]>,
T: AsRef<[&'a str]>,
{
let num_inputs = inputs.as_ref().len();
let (input_tensor, mask) =
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length);
let output = no_grad(|| {
let output = self.zero_shot_classifier.forward_t(
Some(input_tensor),
@ -731,7 +751,7 @@ impl ZeroShotClassificationModel {
None,
false,
);
output.view((num_inputs as i64, labels.len() as i64, -1i64))
output.view((num_inputs as i64, labels.as_ref().len() as i64, -1i64))
});
let scores = output.slice(-1, 0, 3, 2).softmax(-1, Float).select(-1, -1);
@ -745,7 +765,7 @@ impl ZeroShotClassificationModel {
.collect::<Vec<f64>>();
for (label_index, score) in sentence_scores.into_iter().enumerate() {
let label_string = labels[label_index].to_string();
let label_string = labels.as_ref()[label_index].to_string();
let label = Label {
text: label_string,
score,