mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
Updated README
This commit is contained in:
parent
97ee8ee928
commit
fbff61507a
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "rust-bert"
|
||||
version = "0.10.0"
|
||||
version = "0.11.0"
|
||||
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
|
||||
edition = "2018"
|
||||
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
|
||||
|
22
README.md
22
README.md
@ -10,17 +10,17 @@ This repository exposes the model base architecture, task-specific heads (see be
|
||||
|
||||
The following models are currently implemented:
|
||||
|
||||
| |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**|
|
||||
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:|:----:|:----:
|
||||
Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |
|
||||
Sequence classification|✅ |✅ |✅| | |✅ | | |✅ | |
|
||||
Token classification|✅ |✅ | ✅| | | |✅| |✅ | |
|
||||
Question answering|✅ |✅ |✅| | | | | |✅ | |
|
||||
Multiple choices| |✅ |✅| | | | | |✅ | |
|
||||
Next token prediction| | | |✅|✅|✅| | | | |
|
||||
Natural Language Generation| | | |✅|✅|✅| | | | |
|
||||
Summarization | | | | | |✅| | | | |
|
||||
Translation | | | | | |✅| |✅ | |✅|
|
||||
| |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**|**XLNet**|
|
||||
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:|:----:|:----:|:----:
|
||||
Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |✅ |
|
||||
Sequence classification|✅ |✅ |✅| | |✅ | | |✅ | |✅ |
|
||||
Token classification|✅ |✅ | ✅| | | |✅| |✅ | |✅ |
|
||||
Question answering|✅ |✅ |✅| | | | | |✅ | |✅ |
|
||||
Multiple choices| |✅ |✅| | | | | |✅ | |✅ |
|
||||
Next token prediction| | | |✅|✅|✅| | | | |✅ |
|
||||
Natural Language Generation| | | |✅|✅|✅| | | | |✅ |
|
||||
Summarization | | | | | |✅| | | | | |
|
||||
Translation | | | | | |✅| |✅ | |✅| |
|
||||
|
||||
## Ready-to-use pipelines
|
||||
|
||||
|
@ -54,7 +54,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||
|
||||
let tokenized_input = tokenizer.encode_list(&input, 1024, &TruncationStrategy::LongestFirst, 0);
|
||||
let tokenized_input = tokenizer.encode_list(input, 1024, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
|
22
src/lib.rs
22
src/lib.rs
@ -31,17 +31,17 @@
|
||||
//! ```
|
||||
//! - Transformer model base architectures with customized heads. These allow loading pre-trained models for customized inference in Rust
|
||||
//!
|
||||
//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**
|
||||
//! :-----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:
|
||||
//! Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |
|
||||
//! Sequence classification|✅ |✅ |✅| | |✅| | |✅ | |
|
||||
//! Token classification|✅ |✅ | ✅| | | |✅| |✅ | |
|
||||
//! Question answering|✅ |✅ |✅| | | | | |✅ | |
|
||||
//! Multiple choices| |✅ |✅| | | | | |✅ | |
|
||||
//! Next token prediction| | | |✅|✅| | | | | |
|
||||
//! Natural Language Generation| | | |✅|✅| | | | | |
|
||||
//! Summarization| | | | | |✅| | | | |
|
||||
//! Translation| | | | | | | |✅| |✅|
|
||||
//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**|**XLNet**
|
||||
//! :-----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:
|
||||
//! Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |✅|
|
||||
//! Sequence classification|✅ |✅ |✅| | |✅| | |✅ | |✅|
|
||||
//! Token classification|✅ |✅ | ✅| | | |✅| |✅ | |✅|
|
||||
//! Question answering|✅ |✅ |✅| | | | | |✅ | |✅|
|
||||
//! Multiple choices| |✅ |✅| | | | | |✅ | |✅|
|
||||
//! Next token prediction| | | |✅|✅| | | | | |✅|
|
||||
//! Natural Language Generation| | | |✅|✅| | | | | |✅|
|
||||
//! Summarization| | | | | |✅| | | | | |
|
||||
//! Translation| | | | | | | |✅| |✅| |
|
||||
//!
|
||||
//! # Loading pre-trained models
|
||||
//!
|
||||
|
@ -448,13 +448,13 @@ impl TokenizerOption {
|
||||
/// Interface method to convert tokens to ids
|
||||
pub fn convert_tokens_to_ids(&self, tokens: &[String]) -> Vec<i64> {
|
||||
match *self {
|
||||
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
|
||||
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
|
||||
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
|
||||
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
|
||||
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
|
||||
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
|
||||
Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
|
||||
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
|
||||
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
|
||||
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
|
||||
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
|
||||
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
|
||||
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
|
||||
Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -188,7 +188,10 @@ impl NERModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
|
||||
pub fn predict<'a, S>(&self, input: S) -> Vec<Entity>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
self.token_classification_model
|
||||
.predict(input, true, false)
|
||||
.into_iter()
|
||||
|
@ -465,22 +465,32 @@ impl ZeroShotClassificationModel {
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_for_model(
|
||||
fn prepare_for_model<'a, S, T>(
|
||||
&self,
|
||||
inputs: &[&str],
|
||||
labels: &[&str],
|
||||
inputs: S,
|
||||
labels: T,
|
||||
template: Option<Box<dyn Fn(&str) -> String>>,
|
||||
max_len: usize,
|
||||
) -> (Tensor, Tensor) {
|
||||
) -> (Tensor, Tensor)
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
T: AsRef<[&'a str]>,
|
||||
{
|
||||
let label_sentences: Vec<String> = match template {
|
||||
Some(function) => labels.iter().map(|label| function(label)).collect(),
|
||||
Some(function) => labels
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|label| function(label))
|
||||
.collect(),
|
||||
None => labels
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|label| format!("This example is about {}.", label))
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let text_pair_list = inputs
|
||||
.as_ref()
|
||||
.iter()
|
||||
.cartesian_product(label_sentences.iter())
|
||||
.map(|(&s, label)| (s, label.as_str()))
|
||||
@ -577,15 +587,20 @@ impl ZeroShotClassificationModel {
|
||||
/// ]
|
||||
/// .to_vec();
|
||||
/// ```
|
||||
pub fn predict(
|
||||
pub fn predict<'a, S, T>(
|
||||
&self,
|
||||
inputs: &[&str],
|
||||
labels: &[&str],
|
||||
inputs: S,
|
||||
labels: T,
|
||||
template: Option<Box<dyn Fn(&str) -> String>>,
|
||||
max_length: usize,
|
||||
) -> Vec<Label> {
|
||||
let num_inputs = inputs.len();
|
||||
let (input_tensor, mask) = self.prepare_for_model(inputs, labels, template, max_length);
|
||||
) -> Vec<Label>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
T: AsRef<[&'a str]>,
|
||||
{
|
||||
let num_inputs = inputs.as_ref().len();
|
||||
let (input_tensor, mask) =
|
||||
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length);
|
||||
let output = no_grad(|| {
|
||||
let output = self.zero_shot_classifier.forward_t(
|
||||
Some(input_tensor),
|
||||
@ -595,7 +610,7 @@ impl ZeroShotClassificationModel {
|
||||
None,
|
||||
false,
|
||||
);
|
||||
output.view((num_inputs as i64, labels.len() as i64, -1i64))
|
||||
output.view((num_inputs as i64, labels.as_ref().len() as i64, -1i64))
|
||||
});
|
||||
|
||||
let scores = output.softmax(1, Float).select(-1, -1);
|
||||
@ -608,7 +623,7 @@ impl ZeroShotClassificationModel {
|
||||
|
||||
let mut output_labels: Vec<Label> = vec![];
|
||||
for sentence_idx in 0..label_indices.len() {
|
||||
let label_string = labels[label_indices[sentence_idx] as usize].to_string();
|
||||
let label_string = labels.as_ref()[label_indices[sentence_idx] as usize].to_string();
|
||||
let label = Label {
|
||||
text: label_string,
|
||||
score: scores[sentence_idx],
|
||||
@ -713,15 +728,20 @@ impl ZeroShotClassificationModel {
|
||||
/// ]
|
||||
/// .to_vec();
|
||||
/// ```
|
||||
pub fn predict_multilabel(
|
||||
pub fn predict_multilabel<'a, S, T>(
|
||||
&self,
|
||||
inputs: &[&str],
|
||||
labels: &[&str],
|
||||
inputs: S,
|
||||
labels: T,
|
||||
template: Option<Box<dyn Fn(&str) -> String>>,
|
||||
max_length: usize,
|
||||
) -> Vec<Vec<Label>> {
|
||||
let num_inputs = inputs.len();
|
||||
let (input_tensor, mask) = self.prepare_for_model(inputs, labels, template, max_length);
|
||||
) -> Vec<Vec<Label>>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
T: AsRef<[&'a str]>,
|
||||
{
|
||||
let num_inputs = inputs.as_ref().len();
|
||||
let (input_tensor, mask) =
|
||||
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length);
|
||||
let output = no_grad(|| {
|
||||
let output = self.zero_shot_classifier.forward_t(
|
||||
Some(input_tensor),
|
||||
@ -731,7 +751,7 @@ impl ZeroShotClassificationModel {
|
||||
None,
|
||||
false,
|
||||
);
|
||||
output.view((num_inputs as i64, labels.len() as i64, -1i64))
|
||||
output.view((num_inputs as i64, labels.as_ref().len() as i64, -1i64))
|
||||
});
|
||||
let scores = output.slice(-1, 0, 3, 2).softmax(-1, Float).select(-1, -1);
|
||||
|
||||
@ -745,7 +765,7 @@ impl ZeroShotClassificationModel {
|
||||
.collect::<Vec<f64>>();
|
||||
|
||||
for (label_index, score) in sentence_scores.into_iter().enumerate() {
|
||||
let label_string = labels[label_index].to_string();
|
||||
let label_string = labels.as_ref()[label_index].to_string();
|
||||
let label = Label {
|
||||
text: label_string,
|
||||
score,
|
||||
|
Loading…
Reference in New Issue
Block a user