Lib clean-up and doc landing page

Guillaume B 2020-03-21 16:53:15 +01:00
parent 24335a5883
commit f17b0d7da8
24 changed files with 131 additions and 63 deletions

View File

@ -1,9 +1,9 @@
[package]
name = "rust-bert"
version = "0.5.2"
version = "0.5.3"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Native (Distil)BERT implementation for Rust"
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
repository = "https://github.com/guillaume-be/rust-bert"
license = "Apache-2.0"
readme = "README.md"
@ -19,6 +19,13 @@ crate-type = ["lib"]
[[bin]]
name = "convert-tensor"
path = "src/convert-tensor.rs"
doc = false

[features]
doc-only = ["tch/doc-only"]

[package.metadata.docs.rs]
features = [ "doc-only" ]

[dependencies]
rust_tokenizers = "2.0.3"

View File

@ -34,7 +34,7 @@ Extractive question answering from a given question and context. DistilBERT mode
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
let answers = qa_model.predict(vec!(QaInput { question, context }), 1, 32);
let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
```
Output:
@ -56,7 +56,7 @@ This may impact the results and it is recommended to submit prompts of similar l
let input_context_1 = "The dog";
let input_context_2 = "The cat was";
let output = model.generate(Some(input_context_1, input_context_2), 0, 30, true, false,
let output = model.generate(Some(vec!(input_context_1, input_context_2)), 0, 30, true, false,
5, 1.2, 0, 0.9, 1.0, 1.0, 3, 3, None);
```
Example output:

View File

@ -17,7 +17,8 @@ use std::path::PathBuf;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use failure::err_msg;
use rust_bert::{BertConfig, BertForMaskedLM, Config};
use rust_bert::bert::bert::{BertConfig, BertForMaskedLM};
use rust_bert::Config;
fn main() -> failure::Fallible<()> {

View File

@ -18,7 +18,9 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, Trunc
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use failure::err_msg;
use rust_bert::{Config, DistilBertConfig, DistilBertModelMaskedLM};
use rust_bert::distilbert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
use rust_bert::Config;
fn main() -> failure::Fallible<()> {

View File

@ -16,7 +16,7 @@ extern crate dirs;
use std::path::PathBuf;
use tch::Device;
use failure::err_msg;
use rust_bert::{GPT2Generator, LanguageGenerator};
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator};
fn main() -> failure::Fallible<()> {

View File

@ -17,7 +17,8 @@ use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer};
use failure::err_msg;
use rust_bert::{Gpt2Config, Config, GPT2LMHeadModel, LMHeadModel};
use rust_bert::gpt2::gpt2::{Gpt2Config, GPT2LMHeadModel, LMHeadModel};
use rust_bert::Config;
fn main() -> failure::Fallible<()> {

View File

@ -15,8 +15,8 @@ extern crate dirs;
use std::path::PathBuf;
use tch::Device;
use rust_bert::NERModel;
use failure::err_msg;
use rust_bert::pipelines::ner::NERModel;
fn main() -> failure::Fallible<()> {
@ -48,7 +48,7 @@ fn main() -> failure::Fallible<()> {
];
// Run model
let output = ner_model.predict(input.to_vec());
let output = ner_model.predict(&input);
for entity in output {
println!("{:?}", entity);
}

View File

@ -16,8 +16,10 @@ extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::{Gpt2Config, Config, OpenAIGPTLMHeadModel, LMHeadModel};
use failure::err_msg;
use rust_bert::gpt2::gpt2::{Gpt2Config, LMHeadModel};
use rust_bert::openai_gpt::openai_gpt::OpenAIGPTLMHeadModel;
use rust_bert::Config;
fn main() -> failure::Fallible<()> {

View File

@ -16,7 +16,7 @@ extern crate dirs;
use std::path::PathBuf;
use tch::Device;
use failure::err_msg;
use rust_bert::{QuestionAnsweringModel, QaInput};
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
fn main() -> failure::Fallible<()> {

View File

@ -17,7 +17,9 @@ use std::path::PathBuf;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, RobertaTokenizer};
use failure::err_msg;
use rust_bert::{BertConfig, RobertaForMaskedLM, Config};
use rust_bert::bert::bert::BertConfig;
use rust_bert::roberta::roberta::RobertaForMaskedLM;
use rust_bert::Config;
fn main() -> failure::Fallible<()> {

View File

@ -15,8 +15,8 @@ extern crate dirs;
use std::path::PathBuf;
use tch::Device;
use rust_bert::SentimentClassifier;
use failure::err_msg;
use rust_bert::pipelines::sentiment::SentimentClassifier;
fn main() -> failure::Fallible<()> {
@ -49,7 +49,7 @@ fn main() -> failure::Fallible<()> {
];
// Run model
let output = sentiment_classifier.predict(input.to_vec());
let output = sentiment_classifier.predict(&input);
for sentiment in output {
println!("{:?}", sentiment);
}

View File

@ -17,7 +17,7 @@ use std::path::PathBuf;
use tch::Device;
use std::env;
use failure::err_msg;
use rust_bert::{QuestionAnsweringModel, squad_processor};
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, squad_processor};
fn main() -> failure::Fallible<()> {

View File

@ -1,4 +1,4 @@
pub mod bert;
pub mod embeddings;
pub mod attention;
pub mod encoder;
mod attention;
mod encoder;

View File

@ -1,23 +1,71 @@
mod distilbert;
mod bert;
mod roberta;
mod openai_gpt;
mod gpt2;
//! Ready-to-use NLP pipelines and Transformer-based models
//!
//! Rust-native implementation of Transformer-based models, ported from the [Transformers](https://github.com/huggingface/transformers) library using the tch-rs crate and pre-processing from rust-tokenizers.
//! Supports multi-threaded tokenization and GPU inference. This crate exposes the model base architectures, task-specific heads (see below) and ready-to-use pipelines.
//!
//! # Quick Start
//!
//! This crate can be used in two different ways:
//! - Ready-to-use NLP pipelines for Sentiment Analysis, Named Entity Recognition, Question-Answering or Language Generation. More information on these can be found in the `pipelines` module.
//! ```no_run
//! use tch::Device;
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
//!# use std::path::PathBuf;
//!
//!# fn main() -> failure::Fallible<()> {
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("distilbert-qa");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//!
//! let device = Device::cuda_if_available();
//! let qa_model = QuestionAnsweringModel::new(vocab_path,
//! config_path,
//! weights_path, device)?;
//!
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");
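//! // The trailing arguments are assumed to be the number of answers to return per question (1) and the batch size (32)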
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
//! # Ok(())
//! # }
//! ```
//! - Transformer model base architectures with customized heads. These allow loading pre-trained models for customized inference in Rust; a minimal construction sketch follows the table below.
//!
//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**
//! :-----:|:-----:|:-----:|:-----:|:-----:|:-----:
//! Masked LM|✅ |✅ |✅ | | |
//! Sequence classification|✅ |✅ |✅| | |
//! Token classification|✅ |✅ | ✅| | |
//! Question answering|✅ |✅ |✅| | |
//! Multiple choices| |✅ |✅| | |
//! Next token prediction| | | |✅|✅|
//! Natural Language Generation| | | |✅|✅|
//!
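//! As an illustration, each checkmark in the table corresponds to a concrete struct combining the base architecture with a task-specific head, built from a shared configuration (a minimal sketch following the pattern of the crate's examples; the configuration path is a placeholder, and loading pre-trained weights is described in the next section):
//! ```no_run
//! use tch::{nn, Device};
//! use rust_bert::Config;
//! use rust_bert::bert::bert::{BertConfig, BertForMaskedLM};
//!# let config_path = std::path::PathBuf::from("path/to/config.json");
//!
//! // A VarStore holds the model variables on the selected device
//! let vs = nn::VarStore::new(Device::cuda_if_available());
//! // The configuration defines the base architecture; the struct adds the task head.
//! // Other heads (e.g. BertForSequenceClassification) are constructed the same way.
//! let config = BertConfig::from_file(&config_path);
//! let masked_lm = BertForMaskedLM::new(&vs.root(), &config);
//! ```
//!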
//! # Loading pre-trained models
//!
//! The architectures defined in this crate are compatible with models trained with the [Transformers](https://github.com/huggingface/transformers) library.
//! The model configuration and vocabulary are downloaded directly from Huggingface's repository.
//! The model weights need to be converted to a binary format that can be read by Libtorch (the original .bin files are Python pickles and cannot be used directly).
//! A Python script for downloading the required files and running the necessary conversion steps is provided for all model classes in this library.
//! Further models can be loaded by extending the Python scripts to point to the desired model.
//!
//!
//! 1. Compile the package: `cargo build --release`
//! 2. Download the model files & perform the necessary conversions
//!     - Set up a virtual environment and install the dependencies
//!     - Run the conversion script: `python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py`. The dependencies will be downloaded to the user's home directory, under `~/rustbert/{}` (a short Rust loading sketch follows these steps)
//! 3. Run the example: `cargo run --release`
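//!
//! As an illustration of the result (a minimal sketch; the `distilbert` directory name and `model.ot` file name are assumptions based on the download scripts), the converted weights can be loaded into a model through a `tch` `VarStore`:
//! ```no_run
//! use std::path::PathBuf;
//! use tch::{nn, Device};
//! use rust_bert::Config;
//! use rust_bert::distilbert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
//!
//!# fn main() -> failure::Fallible<()> {
//! // Files produced by the download & conversion script, under the user's home directory
//! let mut home: PathBuf = dirs::home_dir().unwrap();
//! home.push("rustbert");
//! home.push("distilbert");
//! let config_path = &home.as_path().join("config.json");
//! let weights_path = &home.as_path().join("model.ot");
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let config = DistilBertConfig::from_file(config_path);
//! let model = DistilBertModelMaskedLM::new(&vs.root(), &config);
//! // Populate the model variables with the converted pre-trained weights
//! vs.load(weights_path)?;
//! # Ok(())
//! # }
//! ```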
//!
pub mod distilbert;
pub mod bert;
pub mod roberta;
pub mod openai_gpt;
pub mod gpt2;
mod common;
mod pipelines;
pub mod pipelines;
pub use common::config::Config;
pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertForTokenClassification, DistilBertForQuestionAnswering};
pub use bert::bert::BertConfig;
pub use bert::bert::{BertModel, BertForSequenceClassification, BertForMaskedLM, BertForQuestionAnswering, BertForTokenClassification, BertForMultipleChoice};
pub use roberta::roberta::{RobertaForSequenceClassification, RobertaForMaskedLM, RobertaForQuestionAnswering, RobertaForTokenClassification, RobertaForMultipleChoice};
pub use gpt2::gpt2::{Gpt2Config, Gpt2Model, GPT2LMHeadModel, LMHeadModel};
pub use openai_gpt::openai_gpt::{OpenAiGptModel, OpenAIGPTLMHeadModel};
pub use pipelines::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
pub use pipelines::ner::{Entity, NERModel};
pub use pipelines::question_answering::{QaInput, QuestionAnsweringModel, squad_processor};
pub use pipelines::generation::{OpenAIGenerator, GPT2Generator, LanguageGenerator};

View File

@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::gpt2::gpt2::LMHeadModel;
use crate::gpt2::gpt2::{LMHeadModel, Gpt2Config, GPT2LMHeadModel};
use tch::{Tensor, Device, nn, no_grad};
use rust_tokenizers::{Tokenizer, OpenAiGptTokenizer, OpenAiGptVocab, Vocab, TruncationStrategy, Gpt2Tokenizer, Gpt2Vocab};
use crate::openai_gpt::openai_gpt::OpenAIGPTLMHeadModel;
use std::path::Path;
use crate::{Gpt2Config, GPT2LMHeadModel};
use crate::common::config::Config;
use rust_tokenizers::tokenization_utils::truncate_sequences;
use tch::kind::Kind::{Int64, Float, Bool};

View File

@ -15,11 +15,11 @@ use rust_tokenizers::bert_tokenizer::BertTokenizer;
use std::path::Path;
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
use crate::{BertForTokenClassification, BertConfig};
use std::collections::HashMap;
use crate::common::config::Config;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;
use crate::bert::bert::{BertForTokenClassification, BertConfig};
#[derive(Debug)]
@ -67,8 +67,8 @@ impl NERModel {
Tensor::stack(tokenized_input.as_slice(), 0).to(self.var_store.device())
}
pub fn predict(&self, input: Vec<&str>) -> Vec<Entity> {
let input_tensor = self.prepare_for_model(input);
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
let input_tensor = self.prepare_for_model(input.to_vec());
let (output, _, _) = no_grad(|| {
self.bert_sequence_classifier
.forward_t(Some(input_tensor.copy()),

View File

@ -17,11 +17,11 @@ use std::path::{Path, PathBuf};
use rust_tokenizers::tokenization_utils::truncate_sequences;
use std::collections::HashMap;
use std::cmp::min;
use crate::{DistilBertForQuestionAnswering, DistilBertConfig};
use tch::nn::VarStore;
use crate::common::config::Config;
use tch::kind::Kind::Float;
use std::fs;
use crate::distilbert::distilbert::{DistilBertForQuestionAnswering, DistilBertConfig};
use crate::Config;
pub struct QaInput {
pub question: String,

View File

@ -12,12 +12,12 @@
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use crate::distilbert::distilbert::{DistilBertModelClassifier, DistilBertConfig};
use std::path::Path;
use tch::{Device, Tensor, Kind, no_grad};
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
use crate::common::config::Config;
use crate::distilbert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
#[derive(Debug, PartialEq)]
@ -68,8 +68,8 @@ impl SentimentClassifier {
Tensor::stack(tokenized_input.as_slice(), 0).to(self.var_store.device())
}
pub fn predict(&self, input: Vec<&str>) -> Vec<Sentiment> {
let input_tensor = self.prepare_for_model(input);
pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
let input_tensor = self.prepare_for_model(input.to_vec());
let (output, _, _) = no_grad(|| {
self.distil_bert_classifier
.forward_t(Some(input_tensor),

View File

@ -4,7 +4,9 @@ extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::{NERModel, BertConfig, BertForMaskedLM, Config, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering};
use rust_bert::bert::bert::{BertConfig, BertForMaskedLM, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering};
use rust_bert::Config;
use rust_bert::pipelines::ner::NERModel;
#[test]
fn bert_masked_lm() -> failure::Fallible<()> {
@ -315,7 +317,7 @@ fn bert_pre_trained_ner() -> failure::Fallible<()> {
];
// Run model
let output = ner_model.predict(input.to_vec());
let output = ner_model.predict(&input);
assert_eq!(output.len(), 4);

View File

@ -3,7 +3,10 @@ use tch::{Device, Tensor, nn, no_grad};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use rust_bert::{SentimentClassifier, SentimentPolarity, DistilBertConfig, DistilBertModelMaskedLM, Config, DistilBertForQuestionAnswering, DistilBertForTokenClassification, QuestionAnsweringModel, QaInput};
use rust_bert::pipelines::sentiment::{SentimentClassifier, SentimentPolarity};
use rust_bert::distilbert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertForQuestionAnswering, DistilBertForTokenClassification};
use rust_bert::Config;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
extern crate failure;
extern crate dirs;
@ -32,7 +35,7 @@ fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
let output = sentiment_classifier.predict(input.to_vec());
let output = sentiment_classifier.predict(&input);
assert_eq!(output.len(), 3 as usize);
assert_eq!(output[0].polarity, SentimentPolarity::Positive);

View File

@ -1,7 +1,8 @@
use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
use rust_bert::{Gpt2Config, GPT2LMHeadModel, Config, LMHeadModel};
use rust_bert::gpt2::gpt2::{Gpt2Config, GPT2LMHeadModel, LMHeadModel};
use rust_bert::Config;
#[test]
fn distilgpt2_lm_model() -> failure::Fallible<()> {

View File

@ -1,7 +1,9 @@
use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
use rust_bert::{Gpt2Config, GPT2LMHeadModel, Config, LMHeadModel, GPT2Generator, LanguageGenerator};
use rust_bert::gpt2::gpt2::{Gpt2Config, GPT2LMHeadModel, LMHeadModel};
use rust_bert::Config;
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator};
#[test]
fn gpt2_lm_model() -> failure::Fallible<()> {
@ -77,7 +79,6 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::cuda_if_available();
// let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let input_context = "The cat";
@ -104,7 +105,6 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::cuda_if_available();
// let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let input_context = "The dog";
@ -166,7 +166,6 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
// Set-up masked LM model
let device = Device::cuda_if_available();
// let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let input_context_1 = "The dog";

View File

@ -1,7 +1,10 @@
use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::{Gpt2Config, OpenAIGPTLMHeadModel, Config, LMHeadModel, OpenAIGenerator, LanguageGenerator};
use rust_bert::gpt2::gpt2::{Gpt2Config, LMHeadModel};
use rust_bert::openai_gpt::openai_gpt::OpenAIGPTLMHeadModel;
use rust_bert::Config;
use rust_bert::pipelines::generation::{OpenAIGenerator, LanguageGenerator};
#[test]
fn openai_gpt_lm_model() -> failure::Fallible<()> {
@ -73,7 +76,6 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::cuda_if_available();
// let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let input_context = "It was an intense machine dialogue. ";
@ -100,7 +102,6 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::cuda_if_available();
// let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let input_context = "The dog is";
@ -129,7 +130,6 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
// Set-up masked LM model
let device = Device::cuda_if_available();
// let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let input_context_1 = "The dog is";
@ -165,7 +165,6 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure:
// Set-up masked LM model
let device = Device::cuda_if_available();
// let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let input_context_1 = "The dog is";

View File

@ -1,7 +1,9 @@
use std::path::PathBuf;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::{BertConfig, Config, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering};
use rust_bert::bert::bert::BertConfig;
use rust_bert::roberta::roberta::{RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering};
use rust_bert::Config;
#[test]
fn roberta_masked_lm() -> failure::Fallible<()> {