Addition of dependencies download example

Guillaume B 2020-04-26 10:51:29 +02:00
parent 0a85d5ba21
commit cb254995fd
27 changed files with 586 additions and 530 deletions

View File

@ -11,39 +11,31 @@
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer};
use failure::err_msg;
use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
use rust_bert::bart::{BartConfig, BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources, BartModel};
use rust_bert::Config;
use rust_bert::common::resources::{Resource, download_resource, RemoteResource};
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("bart-large-cnn");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let merges_path = &home.as_path().join("merges.txt");
let weights_path = &home.as_path().join("model.ot");
if !config_path.is_file() | !vocab_path.is_file() | !merges_path.is_file() | !weights_path.is_file() {
return Err(
err_msg("Could not find required resources to run example. \
Please run ../utils/download_dependencies_bart.py \
in a Python environment with dependencies listed in ../requirements.txt"));
}
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = BartConfig::from_file(config_path);
let mut bart_model = BartForConditionalGeneration::new(&vs.root(), &config, false);
let mut bart_model = BartModel::new(&vs.root(), &config, false);
vs.load(weights_path)?;
// Define input
@ -86,7 +78,7 @@ about exoplanets like K2-18b."];
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (decoder_output, encoder_output, _, _, _, _) = no_grad(|| {
let (decoder_output, encoder_output, _, _, _, _, _) = no_grad(|| {
bart_model
.forward_t(Some(&input_tensor),
None,

View File

@ -11,31 +11,22 @@
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use failure::err_msg;
use rust_bert::Config;
use rust_bert::bert::{BertConfig, BertForMaskedLM};
use rust_bert::bert::{BertConfig, BertForMaskedLM, BertConfigResources, BertVocabResources, BertModelResources};
use rust_bert::common::resources::{Resource, download_resource, RemoteResource};
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("bert");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");
if !config_path.is_file() | !vocab_path.is_file() | !weights_path.is_file() {
return Err(
err_msg("Could not find required resources to run example. \
Please run ../utils/download_dependencies_bert.py \
in a Python environment with dependencies listed in ../requirements.txt"));
}
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -10,34 +10,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::{Device, Tensor, nn, no_grad};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use failure::err_msg;
use rust_bert::Config;
use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelResources};
use rust_bert::common::resources::{Resource, download_resource, RemoteResource};
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("distilbert");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");
if !config_path.is_file() | !vocab_path.is_file() | !weights_path.is_file() {
return Err(
err_msg("Could not find required resources to run example. \
Please run ../utils/download_dependencies_distilbert.py \
in a Python environment with dependencies listed in ../requirements.txt"));
}
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -0,0 +1,163 @@
extern crate failure;
use rust_bert::common::resources::{Resource, download_resource, RemoteResource};
use rust_bert::gpt2::{Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
use rust_bert::distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources};
use rust_bert::openai_gpt::{OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
use rust_bert::roberta::{RobertaConfigResources, RobertaVocabResources, RobertaMergesResources, RobertaModelResources};
use rust_bert::bert::{BertConfigResources, BertVocabResources, BertModelResources};
use rust_bert::bart::{BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources};
/// This example downloads and caches all dependencies used in model tests. This allows for safe
/// multi-threaded testing (two tests using the same resource would otherwise both download the file
/// to the same location).
fn download_distil_gpt2() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_distilbert_sst2() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_distilbert_qa() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_distilbert() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_gpt2() -> failure::Fallible<()> {
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_gpt() -> failure::Fallible<()> {
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_roberta() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_bert() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_bert_ner() -> failure::Fallible<()> {
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_bart() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_bart_cnn() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn main() -> failure::Fallible<()> {
let _ = download_distil_gpt2();
let _ = download_distilbert_sst2();
let _ = download_distilbert_qa();
let _ = download_distilbert();
let _ = download_gpt2();
let _ = download_gpt();
let _ = download_roberta();
let _ = download_bert();
let _ = download_bert_ner();
let _ = download_bart();
let _ = download_bart_cnn();
Ok(())
}
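
Each of the functions above repeats the same pattern: wrap a pretrained identifier in a `Resource::Remote`, then call `download_resource` to fetch and cache it. A minimal sketch of a shared helper — not part of this commit, and assuming the `Resource` and `download_resource` imports already shown in this file — that the per-model functions could delegate to:

```rust
fn download_all(resources: &[Resource]) -> failure::Fallible<()> {
    for resource in resources {
        // download_resource caches the file locally and returns its path; the path is not needed here.
        let _ = download_resource(resource)?;
    }
    Ok(())
}
```

A caller such as `download_bert` would then simply pass the three or four resources it builds as a slice.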

View File

@ -11,33 +11,13 @@
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::Device;
use failure::err_msg;
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig};
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("gpt2");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let merges_path = &home.as_path().join("merges.txt");
let weights_path = &home.as_path().join("model.ot");
if !config_path.is_file() | !vocab_path.is_file() | !merges_path.is_file() | !weights_path.is_file() {
return Err(
err_msg("Could not find required resources to run example. \
Please run ../utils/download_dependencies_gpt2.py \
in a Python environment with dependencies listed in ../requirements.txt"));
}
// Set-up masked LM model
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
max_length: 30,
do_sample: true,
@ -46,8 +26,7 @@ fn main() -> failure::Fallible<()> {
num_return_sequences: 3,
..Default::default()
};
let mut model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path,
generate_config, device)?;
let mut model = GPT2Generator::new(generate_config)?;
let input_context = "The dog";
let second_input_context = "The cat was";

View File

@ -11,33 +11,25 @@
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer};
use failure::err_msg;
use rust_bert::Config;
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
use rust_bert::pipelines::generation::LMHeadModel;
use rust_bert::common::resources::{Resource, download_resource, RemoteResource};
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("gpt2");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let merges_path = &home.as_path().join("merges.txt");
let weights_path = &home.as_path().join("model.ot");
if !config_path.is_file() | !vocab_path.is_file() | !merges_path.is_file() | !weights_path.is_file() {
return Err(
err_msg("Could not find required resources to run example. \
Please run ../utils/download_dependencies_gpt2.py \
in a Python environment with dependencies listed in ../requirements.txt"));
}
// Resources set-up
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -11,34 +11,26 @@
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use failure::err_msg;
use rust_bert::Config;
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
use rust_bert::pipelines::generation::LMHeadModel;
use rust_bert::common::resources::{Resource, download_resource, RemoteResource};
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("openai-gpt");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let merges_path = &home.as_path().join("merges.txt");
let weights_path = &home.as_path().join("model.ot");
if !config_path.is_file() | !vocab_path.is_file() | !merges_path.is_file() | !weights_path.is_file() {
return Err(
err_msg("Could not find required resources to run example. \
Please run ../utils/download_dependencies_openaigpt.py \
in a Python environment with dependencies listed in ../requirements.txt"));
}
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -11,33 +11,25 @@
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, RobertaTokenizer};
use failure::err_msg;
use rust_bert::Config;
use rust_bert::bert::BertConfig;
use rust_bert::roberta::RobertaForMaskedLM;
use rust_bert::roberta::{RobertaForMaskedLM, RobertaVocabResources, RobertaConfigResources, RobertaMergesResources, RobertaModelResources};
use rust_bert::common::resources::{Resource, download_resource, RemoteResource};
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("roberta");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let merges_path = &home.as_path().join("merges.txt");
let weights_path = &home.as_path().join("model.ot");
if !config_path.is_file() | !vocab_path.is_file() | !merges_path.is_file() | !weights_path.is_file() {
return Err(
err_msg("Could not find required resources to run example. \
Please run ../utils/download_dependencies_roberta.py \
in a Python environment with dependencies listed in ../requirements.txt"));
}
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -12,22 +12,26 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `RobertaTokenizer` using a `vocab.txt` vocabulary and `merges.txt` 2-gram merges
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("bart-large-cnn");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let merges_path = &home.as_path().join("merges.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! use rust_tokenizers::RobertaTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::common::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/merges.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);

View File

@ -16,22 +16,24 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("bert");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::bert::{BertForMaskedLM, BertConfig};
//! use rust_bert::Config;
//! use rust_bert::common::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);

View File

@ -6,6 +6,7 @@ use tokio::prelude::*;
extern crate dirs;
#[derive(PartialEq, Clone)]
pub enum Resource {
Local(LocalResource),
Remote(RemoteResource),
@ -20,10 +21,12 @@ impl Resource {
}
}
#[derive(PartialEq, Clone)]
pub struct LocalResource {
pub local_path: PathBuf
}
#[derive(PartialEq, Clone)]
pub struct RemoteResource {
pub url: String,
pub local_path: PathBuf,
@ -62,6 +65,7 @@ pub async fn download_resource(resource: &Resource) -> failure::Fallible<&PathBu
let target = &remote_resource.local_path;
let url = &remote_resource.url;
if !target.exists() {
println!("Downloading {} to {:?}", url, target);
fs::create_dir_all(target.parent().unwrap())?;
let client = Client::new();

View File

@ -15,22 +15,24 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("distilbert");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig};
//! use rust_bert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelResources};
//! use rust_bert::common::resources::{Resource, download_resource, RemoteResource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);

View File

@ -11,22 +11,26 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `Gpt2Tokenizer` using a `vocab.txt` vocabulary and `merges.txt` 2-gram merges
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("gpt2");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let merges_path = &home.as_path().join("merges.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! use rust_tokenizers::Gpt2Tokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
//! use rust_bert::common::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/merges.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);

View File

@ -15,22 +15,10 @@
//!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! ```no_run
//! use tch::Device;
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
//!# use std::path::PathBuf;
//!
//!# fn main() -> failure::Fallible<()> {
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("distilbert-qa");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//!
//! let device = Device::cuda_if_available();
//! let qa_model = QuestionAnsweringModel::new(vocab_path,
//! config_path,
//! weights_path, device)?;
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//!
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");

View File

@ -11,23 +11,26 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `GptTokenizer` using a `vocab.txt` vocabulary and `merges.txt` 2-gram merges
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("openai-gpt");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let merges_path = &home.as_path().join("merges.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! use rust_tokenizers::OpenAiGptTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::gpt2::Gpt2Config;
//! use rust_bert::openai_gpt::OpenAiGptModel;
//! use rust_bert::common::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/merges.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);

View File

@ -23,18 +23,9 @@
//! The dependencies will be downloaded to the user's home directory, under ~/rustbert/gpt2 (~/rustbert/openai-gpt respectively)
//!
//! ```no_run
//!# use std::path::PathBuf;
//!# use tch::Device;
//!# fn main() -> failure::Fallible<()> {
//! use rust_bert::pipelines::generation::{GenerateConfig, GPT2Generator, LanguageGenerator};
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("gpt2");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let merges_path = &home.as_path().join("merges.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! let device = Device::cuda_if_available();
//!
//! let generate_config = GenerateConfig {
//! max_length: 30,
//! do_sample: true,
@ -43,8 +34,7 @@
//! num_return_sequences: 3,
//! ..Default::default()
//! };
//! let mut gpt2_generator = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path,
//! generate_config, device)?;
//! let mut gpt2_generator = GPT2Generator::new(generate_config)?;
//!
//! let input_context = "The dog";
//! let second_input_context = "The cat was";
@ -70,20 +60,28 @@
use tch::{Tensor, Device, nn, no_grad};
use rust_tokenizers::{Tokenizer, OpenAiGptTokenizer, OpenAiGptVocab, Vocab, Gpt2Tokenizer, Gpt2Vocab, RobertaTokenizer, RobertaVocab, TruncationStrategy};
use std::path::Path;
use tch::kind::Kind::Int64;
use self::ordered_float::OrderedFloat;
use itertools::Itertools;
use crate::openai_gpt::OpenAIGPTLMHeadModel;
use crate::gpt2::{Gpt2Config, GPT2LMHeadModel};
use crate::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptModelResources, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources};
use crate::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ModelResources, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources};
use crate::Config;
use crate::pipelines::generation::private_generation_utils::PrivateLanguageGenerator;
use crate::bart::{BartConfig, BartForConditionalGeneration};
use crate::bart::{BartConfig, BartForConditionalGeneration, BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources};
use crate::common::resources::{Resource, RemoteResource, download_resource};
extern crate ordered_float;
/// # Configuration for text generation
pub struct GenerateConfig {
/// Model weights resource (default: pretrained GPT2 model)
pub model_resource: Resource,
/// Config resource (default: pretrained GPT2 model)
pub config_resource: Resource,
/// Vocab resource (default: pretrained GPT2 model)
pub vocab_resource: Resource,
/// Merges resource (default: pretrained GPT2 model)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
/// Maximum sequence length (default: 20)
@ -108,11 +106,17 @@ pub struct GenerateConfig {
pub no_repeat_ngram_size: u64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
}
impl Default for GenerateConfig {
fn default() -> GenerateConfig {
GenerateConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
min_length: 0,
max_length: 20,
do_sample: true,
@ -125,6 +129,7 @@ impl Default for GenerateConfig {
length_penalty: 1.0,
no_repeat_ngram_size: 3,
num_return_sequences: 1,
device: Device::cuda_if_available(),
}
}
}
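
With the resource and device fields now part of `GenerateConfig`, a generator can be pointed at local files while keeping the remaining generation defaults. A minimal sketch, assuming the paths are placeholders supplied by the caller:

```rust
use std::path::PathBuf;
use rust_bert::common::resources::{LocalResource, Resource};
use rust_bert::pipelines::generation::GenerateConfig;

fn local_gpt2_generate_config() -> GenerateConfig {
    GenerateConfig {
        // Hypothetical local paths; every field not listed keeps its pretrained GPT2 default.
        model_resource: Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot") }),
        config_resource: Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json") }),
        vocab_resource: Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt") }),
        merges_resource: Resource::Local(LocalResource { local_path: PathBuf::from("path/to/merges.txt") }),
        max_length: 30,
        ..Default::default()
    }
}
```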
@ -167,27 +172,13 @@ impl OpenAIGenerator {
///
/// # Arguments
///
/// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `merges_path` - Path to the bpe merges, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `weights_path` - Path to the model weight files. These need to be converted form the `.bin` to `.ot` format using the utility script provided.
/// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)`
/// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
///
/// # Example
///
/// ```no_run
///# use std::path::PathBuf;
///# use tch::Device;
///# fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::generation::{GenerateConfig, OpenAIGenerator};
///# let mut home: PathBuf = dirs::home_dir().unwrap();
///# home.push("rustbert");
///# home.push("openai-gpt");
///# let config_path = &home.as_path().join("config.json");
///# let vocab_path = &home.as_path().join("vocab.txt");
///# let merges_path = &home.as_path().join("merges.txt");
///# let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// do_sample: true,
@ -196,21 +187,50 @@ impl OpenAIGenerator {
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let gpt_generator = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path,
/// generate_config, device)?;
/// let gpt_generator = OpenAIGenerator::new(generate_config)?;
///# Ok(())
///# }
/// ```
///
pub fn new(vocab_path: &Path, merges_path: &Path, config_path: &Path, weight_path: &Path,
generate_config: GenerateConfig, device: Device)
-> failure::Fallible<OpenAIGenerator> {
pub fn new(generate_config: GenerateConfig) -> failure::Fallible<OpenAIGenerator> {
generate_config.validate();
// The following allows keeping the same GenerateConfig Default for the GPT, GPT2 and BART models
let model_resource = if &generate_config.model_resource == &Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) {
Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT))
} else {
generate_config.model_resource.clone()
};
let config_resource = if &generate_config.config_resource == &Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)) {
Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if &generate_config.vocab_resource == &Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)) {
Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if &generate_config.merges_resource == &Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)) {
Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT))
} else {
generate_config.merges_resource.clone()
};
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&model_resource)?;
let device = generate_config.device;
let mut var_store = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let config = Gpt2Config::from_file(config_path);
let model = OpenAIGPTLMHeadModel::new(&var_store.root(), &config);
var_store.load(weight_path)?;
var_store.load(weights_path)?;
let bos_token_id = None;
let eos_token_ids = None;
@ -257,27 +277,14 @@ impl GPT2Generator {
///
/// # Arguments
///
/// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `merges_path` - Path to the bpe merges, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `weights_path` - Path to the model weight files. These need to be converted form the `.bin` to `.ot` format using the utility script provided.
/// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)`
/// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
///
/// # Example
///
/// ```no_run
///# use std::path::PathBuf;
///# use tch::Device;
///# fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::generation::{GenerateConfig, GPT2Generator};
///# let mut home: PathBuf = dirs::home_dir().unwrap();
///# home.push("rustbert");
///# home.push("gpt2");
///# let config_path = &home.as_path().join("config.json");
///# let vocab_path = &home.as_path().join("vocab.txt");
///# let merges_path = &home.as_path().join("merges.txt");
///# let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
///
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// do_sample: true,
@ -286,21 +293,24 @@ impl GPT2Generator {
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let gpt2_generator = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path,
/// generate_config, device)?;
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
///# Ok(())
///# }
/// ```
///
pub fn new(vocab_path: &Path, merges_path: &Path, config_path: &Path, weight_path: &Path,
generate_config: GenerateConfig, device: Device)
-> failure::Fallible<GPT2Generator> {
pub fn new(generate_config: GenerateConfig) -> failure::Fallible<GPT2Generator> {
let config_path = download_resource(&generate_config.config_resource)?;
let vocab_path = download_resource(&generate_config.vocab_resource)?;
let merges_path = download_resource(&generate_config.merges_resource)?;
let weights_path = download_resource(&generate_config.model_resource)?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = Gpt2Config::from_file(config_path);
let model = GPT2LMHeadModel::new(&var_store.root(), &config);
var_store.load(weight_path)?;
var_store.load(weights_path)?;
let bos_token_id = Some(tokenizer.vocab().token_to_id(Gpt2Vocab::bos_value()));
let eos_token_ids = Some(vec!(tokenizer.vocab().token_to_id(Gpt2Vocab::eos_value())));
@ -389,21 +399,50 @@ impl BartGenerator {
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let bart_generator = BartGenerator::new(vocab_path, merges_path, config_path, weights_path,
/// generate_config, device)?;
/// let bart_generator = BartGenerator::new(generate_config)?;
///# Ok(())
///# }
/// ```
///
pub fn new(vocab_path: &Path, merges_path: &Path, config_path: &Path, weight_path: &Path,
generate_config: GenerateConfig, device: Device)
-> failure::Fallible<BartGenerator> {
pub fn new(generate_config: GenerateConfig) -> failure::Fallible<BartGenerator> {
// The following allows keeping the same GenerateConfig Default for the GPT, GPT2 and BART models
let model_resource = if &generate_config.model_resource == &Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) {
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART))
} else {
generate_config.model_resource.clone()
};
let config_resource = if &generate_config.config_resource == &Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)) {
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if &generate_config.vocab_resource == &Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)) {
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if &generate_config.merges_resource == &Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)) {
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART))
} else {
generate_config.merges_resource.clone()
};
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&model_resource)?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = BartConfig::from_file(config_path);
let model = BartForConditionalGeneration::new(&var_store.root(), &config, true);
var_store.load(weight_path)?;
var_store.load(weights_path)?;
let bos_token_id = Some(0);
let eos_token_ids = Some(match config.eos_token_id {
@ -1063,8 +1102,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: PrivateL
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let mut gpt2_generator = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path,
/// generate_config, device)?;
/// let mut gpt2_generator = GPT2Generator::new(generate_config)?;
/// let input_context = "The dog";
/// let second_input_context = "The cat was";
/// let output = gpt2_generator.generate(Some(vec!(input_context, second_input_context)), None);

View File

@ -6,20 +6,9 @@
//! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
//!
//! ```no_run
//!# use std::path::PathBuf;
//!# use tch::Device;
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
//!# fn main() -> failure::Fallible<()> {
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("distilbert-qa");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! let device = Device::cuda_if_available();
//! let qa_model = QuestionAnsweringModel::new(vocab_path,
//! config_path,
//! weights_path, device)?;
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//!
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");
@ -50,20 +39,11 @@
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//!
//! ```no_run
//!# use std::path::PathBuf;
//!# use tch::Device;
//!# fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::summarization::SummarizationModel;
//! let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("bart-large-cnn");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let merges_path = &home.as_path().join("merges.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! let device = Device::cuda_if_available();
//! let mut model = SummarizationModel::new(vocab_path, merges_path, config_path, weights_path, Default::default(), device)?;
//!
//! let mut model = SummarizationModel::new(Default::default())?;
//!
//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
//!from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
@ -110,20 +90,10 @@
//! This may impact the results and it is recommended to submit prompts of similar length for best results. Additional information on the input parameters for generation is provided in this module's documentation.
//!
//! ```no_run
//!# use std::path::PathBuf;
//!# use tch::Device;
//! use rust_bert::pipelines::generation::GPT2Generator;
//!# fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator;
//! let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("gpt2");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let merges_path = &home.as_path().join("merges.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! let device = Device::cuda_if_available();
//! let mut model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path, Default::default(), device)?;
//! let mut model = GPT2Generator::new(Default::default())?;
//! let input_context_1 = "The dog";
//! let input_context_2 = "The cat was";
//! let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
@ -147,26 +117,15 @@
//! #### 4. Sentiment analysis
//! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
//! ```no_run
//!# use std::path::PathBuf;
//!# use tch::Device;
//! use rust_bert::pipelines::sentiment::SentimentModel;
//!# fn main() -> failure::Fallible<()> {
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("distilbert_sst2");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! let device = Device::cuda_if_available();
//! let sentiment_classifier = SentimentModel::new(vocab_path,
//! config_path,
//! weights_path, device)?;
//! let sentiment_model = SentimentModel::new(Default::default())?;
//! let input = [
//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
//! "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
//! ];
//! let output = sentiment_classifier.predict(&input);
//! let output = sentiment_model.predict(&input);
//!# Ok(())
//!# }
//! ```
@ -188,20 +147,9 @@
//! #### 5. Named Entity Recognition
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! ```no_run
//!# use std::path::PathBuf;
//!# use tch::Device;
//! use rust_bert::pipelines::ner::NERModel;
//!# fn main() -> failure::Fallible<()> {
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("bert-ner");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! let device = Device::cuda_if_available();
//! let ner_model = NERModel::new(vocab_path,
//! config_path,
//! weights_path, device)?;
//! let ner_model = NERModel::new(Default::default())?;
//! let input = [
//! "My name is Amy. I live in Paris.",
//! "Paris is a city in France."

View File

@ -19,20 +19,10 @@
//! The dependencies will be downloaded to the user's home directory, under ~/rustbert/bert-ner
//!
//! ```no_run
//!# use std::path::PathBuf;
//!# use tch::Device;
//! use rust_bert::pipelines::ner::NERModel;
//!# fn main() -> failure::Fallible<()> {
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("bert-ner");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! let device = Device::cuda_if_available();
//! let ner_model = NERModel::new(vocab_path,
//! config_path,
//! weights_path, device)?;
//! let ner_model = NERModel::new(Default::default())?;
//!
//! let input = [
//! "My name is Amy. I live in Paris.",
//! "Paris is a city in France."

View File

@ -65,7 +65,7 @@
use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator};
use tch::Device;
use crate::common::resources::{Resource, RemoteResource, download_resource};
use crate::common::resources::{Resource, RemoteResource};
use crate::bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources};
/// # Configuration for text summarization
@ -158,6 +158,10 @@ impl SummarizationModel {
pub fn new(summarization_config: SummarizationConfig)
-> failure::Fallible<SummarizationModel> {
let generate_config = GenerateConfig {
model_resource: summarization_config.model_resource,
config_resource: summarization_config.config_resource,
merges_resource: summarization_config.merges_resource,
vocab_resource: summarization_config.vocab_resource,
min_length: summarization_config.min_length,
max_length: summarization_config.max_length,
do_sample: summarization_config.do_sample,
@ -170,16 +174,10 @@ impl SummarizationModel {
length_penalty: summarization_config.length_penalty,
no_repeat_ngram_size: summarization_config.no_repeat_ngram_size,
num_return_sequences: summarization_config.num_return_sequences,
device: summarization_config.device,
};
let config_path = download_resource(&summarization_config.config_resource)?;
let vocab_path = download_resource(&summarization_config.vocab_resource)?;
let merges_path = download_resource(&summarization_config.merges_resource)?;
let weights_path = download_resource(&summarization_config.model_resource)?;
let device = summarization_config.device;
let model = BartGenerator::new(vocab_path, merges_path, config_path, weights_path,
generate_config, device)?;
let model = BartGenerator::new(generate_config)?;
Ok(SummarizationModel { model })
}
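With the refactor, resource and device handling moves from `SummarizationModel::new` into the `GenerateConfig`, so a caller now only assembles a `SummarizationConfig`. The sketch below illustrates the resulting call site; it is an assumption-laden example, not code from this commit: the module path `rust_bert::pipelines::summarization`, a `Default` implementation for `SummarizationConfig` pre-filled with the BART CNN resources, and a `summarize` entry point returning printable strings are all assumed.

```rust
// Caller-side sketch of the refactored summarization API (not part of this
// commit). Assumed: the module path, a `Default` impl for `SummarizationConfig`
// pre-filled with the BART CNN resources, and a `summarize` method returning
// printable summaries.
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use tch::Device;

fn main() -> failure::Fallible<()> {
    let summarization_config = SummarizationConfig {
        min_length: 56,
        max_length: 142,
        device: Device::cuda_if_available(),
        ..Default::default()
    };
    let summarization_model = SummarizationModel::new(summarization_config)?;

    let input = ["The router forwards packets between networks and keeps a routing table."];
    for summary in summarization_model.summarize(&input) {
        println!("{}", summary);
    }
    Ok(())
}
```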

View File

@ -16,23 +16,27 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `RobertaTokenizer` using a `vocab.txt` vocabulary and `merges.txt` 2-gram merges
//! Pretrained models are available and can be downloaded using `RemoteResource`s.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("bert");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let merges_path = &home.as_path().join("merges.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! use rust_tokenizers::RobertaTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::bert::BertConfig;
//! use rust_bert::Config;
//! use rust_bert::roberta::RobertaForMaskedLM;
//! use rust_bert::common::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/merges.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);

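The note added above says pretrained models can be fetched with remote resources, while the visible doc example sticks to `LocalResource`. The remote counterpart is sketched below; it reuses the `RobertaConfigResources`/`RobertaVocabResources`/`RobertaMergesResources`/`RobertaModelResources` pointers exercised in the RoBERTa tests further down, and the `rust_bert::roberta` import path for them is an assumption, mirroring the `rust_bert::bart` imports shown earlier.

```rust
// Remote counterpart of the LocalResource doc example above (a sketch, not part
// of this commit). The `rust_bert::roberta` import path for the resource
// pointers is assumed; the pointers themselves are the ones used in the tests.
use rust_bert::common::resources::{Resource, RemoteResource, download_resource};
use rust_bert::roberta::{RobertaConfigResources, RobertaVocabResources, RobertaMergesResources, RobertaModelResources};

fn main() -> failure::Fallible<()> {
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));

    // Each call downloads the file on first use and returns the cached local path.
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let merges_path = download_resource(&merges_resource)?;
    let weights_path = download_resource(&weights_resource)?;
    println!("{:?} {:?} {:?} {:?}", config_path, vocab_path, merges_path, weights_path);
    Ok(())
}
```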
View File

@ -9,14 +9,14 @@ use rust_bert::common::resources::{Resource, RemoteResource, download_resource};
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -14,12 +14,12 @@ use std::collections::HashMap;
#[test]
fn bert_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;
@ -80,10 +80,10 @@ fn bert_masked_lm() -> failure::Fallible<()> {
#[test]
fn bert_for_sequence_classification() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
let device = Device::Cpu;
@ -137,10 +137,10 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn bert_for_multiple_choice() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
let device = Device::Cpu;
@ -187,10 +187,10 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn bert_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
let device = Device::Cpu;
@ -245,10 +245,10 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
#[test]
fn bert_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
let device = Device::Cpu;

View File

@ -10,7 +10,6 @@ use rust_bert::common::resources::{Resource, RemoteResource, download_resource};
use std::collections::HashMap;
extern crate failure;
extern crate dirs;
#[test]
fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
@ -40,14 +39,13 @@ fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
#[test]
fn distilbert_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::cuda_if_available();
@ -104,10 +102,10 @@ fn distilbert_masked_lm() -> failure::Fallible<()> {
fn distilbert_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up masked LM model
let device = Device::cuda_if_available();
@ -153,10 +151,10 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
fn distilbert_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up masked LM model
let device = Device::cuda_if_available();

View File

@ -8,14 +8,14 @@ use rust_bert::common::resources::{Resource, RemoteResource, download_resource};
#[test]
fn distilgpt2_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -8,14 +8,14 @@ use rust_bert::common::resources::{RemoteResource, Resource, download_resource};
#[test]
fn gpt2_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;
@ -70,19 +70,18 @@ fn gpt2_lm_model() -> failure::Fallible<()> {
#[test]
fn gpt2_generation_greedy() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up masked LM model
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource,
max_length: 40,
do_sample: false,
num_beams: 1,
@ -90,8 +89,7 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
repetition_penalty: 1.1,
..Default::default()
};
let mut model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path,
generate_config, device)?;
let mut model = GPT2Generator::new(generate_config)?;
let input_context = "The cat";
let output = model.generate(Some(vec!(input_context)), None);
@ -104,19 +102,18 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
#[test]
fn gpt2_generation_beam_search() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up masked LM model
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource,
max_length: 20,
do_sample: false,
num_beams: 5,
@ -124,8 +121,7 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
num_return_sequences: 3,
..Default::default()
};
let mut model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path,
generate_config, device)?;
let mut model = GPT2Generator::new(generate_config)?;
let input_context = "The dog";
let output = model.generate(Some(vec!(input_context)), None);
@ -140,21 +136,18 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
#[test]
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up masked LM model
let device = Device::cuda_if_available();
// let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
let generate_config = GenerateConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource,
max_length: 20,
do_sample: false,
num_beams: 5,
@ -162,8 +155,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
num_return_sequences: 3,
..Default::default()
};
let mut model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path,
generate_config, device)?;
let mut model = GPT2Generator::new(generate_config)?;
let input_context_1 = "The dog";
let input_context_2 = "The cat";
@ -182,19 +174,18 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
#[test]
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up masked LM model
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource,
max_length: 20,
do_sample: false,
num_beams: 5,
@ -202,8 +193,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
num_return_sequences: 3,
..Default::default()
};
let mut model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path,
generate_config, device)?;
let mut model = GPT2Generator::new(generate_config)?;
let input_context_1 = "The dog";
let input_context_2 = "The cat was";

View File

@ -9,14 +9,14 @@ use rust_bert::common::resources::{RemoteResource, Resource, download_resource};
#[test]
fn openai_gpt_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;
@ -43,7 +43,7 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output,_, _, _, _) = openai_gpt.forward_t(
let (output, _, _, _, _) = openai_gpt.forward_t(
&Some(input_tensor),
&None,
&None,
@ -68,18 +68,17 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
// Set-up masked LM model
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource,
max_length: 40,
do_sample: false,
num_beams: 1,
@ -88,8 +87,7 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
temperature: 1.1,
..Default::default()
};
let mut model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path,
generate_config, device)?;
let mut model = OpenAIGenerator::new(generate_config)?;
let input_context = "It was an intense machine dialogue. ";
let output = model.generate(Some(vec!(input_context)), None);
@ -103,18 +101,17 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
// Set-up masked LM model
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource,
max_length: 20,
do_sample: false,
num_beams: 5,
@ -122,8 +119,7 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
num_return_sequences: 3,
..Default::default()
};
let mut model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path,
generate_config, device)?;
let mut model = OpenAIGenerator::new(generate_config)?;
let input_context = "The dog is";
let output = model.generate(Some(vec!(input_context)), None);
@ -139,18 +135,17 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
// Set-up masked LM model
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource,
max_length: 20,
do_sample: false,
num_beams: 5,
@ -158,8 +153,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
num_return_sequences: 3,
..Default::default()
};
let mut model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path,
generate_config, device)?;
let mut model = OpenAIGenerator::new(generate_config)?;
let input_context_1 = "The dog is";
let input_context_2 = "The cat";
@ -182,18 +176,17 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
#[test]
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
// Set-up masked LM model
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource,
max_length: 20,
do_sample: false,
num_beams: 5,
@ -201,8 +194,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure:
num_return_sequences: 3,
..Default::default()
};
let mut model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path,
generate_config, device)?;
let mut model = OpenAIGenerator::new(generate_config)?;
let input_context_1 = "The dog is";
let input_context_2 = "The cat was in";

View File

@ -9,14 +9,14 @@ use rust_bert::common::resources::{RemoteResource, Resource, download_resource};
#[test]
fn roberta_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
let device = Device::Cpu;
@ -77,12 +77,12 @@ fn roberta_masked_lm() -> failure::Fallible<()> {
#[test]
fn roberta_for_sequence_classification() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
// Set-up model
let device = Device::Cpu;
@ -136,12 +136,12 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn roberta_for_multiple_choice() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
// Set-up model
let device = Device::Cpu;
@ -189,12 +189,12 @@ fn roberta_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn roberta_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
// Set-up model
let device = Device::Cpu;
@ -249,12 +249,12 @@ fn roberta_for_token_classification() -> failure::Fallible<()> {
#[test]
fn roberta_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_dependency = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let merges_path = download_resource(&merges_dependency)?;
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
// Set-up model
let device = Device::Cpu;