Refactor: Feature gate remote resource (#223)

* get_local_path as trait LocalPathProvider

* Remove config default impls

* Feature gate RemoteResource

* translation_builder refactoring to have remote fetching grouped

* Include dirs crate in remote feature gate

* Examples fixes

* Benches fixes

* Tests fix

* Remove Box from constructor parameters

* Fix examples no-Box

* Fix benches no-Box

* Fix tests no-Box

* Fix doc comment code

* Fix documentation `Resource` -> `ResourceProvider`

* moved remote local at same level

* moved ResourceProvider to resources mod

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
This commit is contained in:
Jonas Hedman Engström 2022-02-25 22:24:03 +01:00 committed by GitHub
parent 23c5d9112a
commit 9b22c2482a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
91 changed files with 1343 additions and 1807 deletions

View File

@ -50,8 +50,10 @@ harness = false
opt-level = 3
[features]
default = ["remote"]
doc-only = ["tch/doc-only"]
all-tests = []
remote = [ "cached-path", "dirs", "lazy_static" ]
[package.metadata.docs.rs]
features = ["doc-only"]
@ -61,14 +63,15 @@ rust_tokenizers = "~7.0.1"
tch = "~0.6.1"
serde_json = "1.0.73"
serde = { version = "1.0.132", features = ["derive"] }
dirs = "4.0.0"
ordered-float = "2.8.0"
cached-path = "0.5.1"
lazy_static = "1.4.0"
uuid = { version = "0.8.2", features = ["v4"] }
thiserror = "1.0.30"
half = "1.8.2"
cached-path = { version = "0.5.1", optional = true }
dirs = { version = "4.0.0", optional = true }
lazy_static = { version = "1.4.0", optional = true }
[dev-dependencies]
anyhow = "1.0.51"
csv = "1.1.6"

View File

@ -7,21 +7,17 @@ use rust_bert::gpt2::{
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use std::time::{Duration, Instant};
use tch::Device;
fn create_text_generation_model() -> TextGenerationModel {
let config = TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::GPT2,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
)),
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
min_length: 0,
max_length: 30,
do_sample: true,

View File

@ -7,7 +7,7 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
squad_processor, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use std::env;
use std::path::PathBuf;
use std::time::{Duration, Instant};
@ -17,11 +17,9 @@ static BATCH_SIZE: usize = 64;
fn create_qa_model() -> QuestionAnsweringModel {
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
false,
@ -54,11 +52,9 @@ fn qa_load_model(iters: u64) -> Duration {
let start = Instant::now();
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
false,

View File

@ -19,21 +19,21 @@ use rust_bert::gpt_neo::{
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use tch::Device;
fn main() -> anyhow::Result<()> {
// Set-up model resources
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
GptNeoConfigResources::GPT_NEO_125M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
GptNeoVocabResources::GPT_NEO_125M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
GptNeoMergesResources::GPT_NEO_125M,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
GptNeoModelResources::GPT_NEO_125M,
));
let generate_config = TextGenerationConfig {

View File

@ -19,21 +19,21 @@ use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGeneration
use rust_bert::reformer::{
ReformerConfigResources, ReformerModelResources, ReformerVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
// Set-up model
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT,
));
let generate_config = TextGenerationConfig {

View File

@ -16,21 +16,21 @@ extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED,
));

View File

@ -15,7 +15,7 @@ extern crate anyhow;
use rust_bert::bert::{
BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
@ -23,12 +23,9 @@ use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let weights_resource = RemoteResource::from_pretrained(BertModelResources::BERT);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;

View File

@ -4,23 +4,23 @@ use rust_bert::deberta::{
DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Kind, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
DebertaConfigResources::DEBERTA_BASE_MNLI,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
DebertaVocabResources::DEBERTA_BASE_MNLI,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
DebertaMergesResources::DEBERTA_BASE_MNLI,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
DebertaModelResources::DEBERTA_BASE_MNLI,
));

View File

@ -17,17 +17,15 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
// Set-up Question Answering model
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
None, //merges resource only relevant with ModelType::Roberta
false,
false,

View File

@ -20,24 +20,18 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
// Set-up Question Answering model
let config = QuestionAnsweringConfig::new(
ModelType::Longformer,
Resource::Remote(RemoteResource::from_pretrained(
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
)),
Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
)),
Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
)),
Some(Resource::Remote(RemoteResource::from_pretrained(
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
Some(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
))),
)),
false,
None,
false,

View File

@ -15,17 +15,17 @@ extern crate anyhow;
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
// Set-up classifier
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
FNetConfigResources::BASE_SST2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
FNetVocabResources::BASE_SST2,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
FNetModelResources::BASE_SST2,
));

View File

@ -16,20 +16,20 @@ use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use tch::Device;
fn main() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
BartModelResources::DISTILBART_CNN_6_6,
));

View File

@ -15,17 +15,17 @@ extern crate anyhow;
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use tch::Device;
fn main() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
PegasusConfigResources::CNN_DAILYMAIL,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
PegasusVocabResources::CNN_DAILYMAIL,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
PegasusModelResources::CNN_DAILYMAIL,
));

View File

@ -17,17 +17,17 @@ use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationMode
use rust_bert::prophetnet::{
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use tch::Device;
fn main() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
));

View File

@ -14,16 +14,14 @@ extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
fn main() -> anyhow::Result<()> {
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL);
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
let weights_resource = RemoteResource::from_pretrained(T5ModelResources::T5_SMALL);
let summarization_config = SummarizationConfig::new(
ModelType::T5,
weights_resource,

View File

@ -16,21 +16,15 @@ use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
// Load a configuration
let config = TokenClassificationConfig::new(
ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
)),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
)),
Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
)),
RemoteResource::from_pretrained(BertModelResources::BERT_NER),
RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
false,

View File

@ -18,22 +18,14 @@ use rust_bert::m2m_100::{
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use tch::Device;
fn main() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
));
let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
let source_languages = M2M100SourceLanguages::M2M100_418M;
let target_languages = M2M100TargetLanguages::M2M100_418M;

View File

@ -19,22 +19,14 @@ use rust_bert::marian::{
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use tch::Device;
fn main() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianModelResources::ENGLISH2CHINESE,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianConfigResources::ENGLISH2CHINESE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianVocabResources::ENGLISH2CHINESE,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianSpmResources::ENGLISH2CHINESE,
));
let model_resource = RemoteResource::from_pretrained(MarianModelResources::ENGLISH2CHINESE);
let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ENGLISH2CHINESE);
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ENGLISH2CHINESE);
let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ENGLISH2CHINESE);
let source_languages = MarianSourceLanguages::ENGLISH2CHINESE;
let target_languages = MarianTargetLanguages::ENGLISH2CHINESE;

View File

@ -18,22 +18,16 @@ use rust_bert::mbart::{
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use tch::Device;
fn main() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
));
let model_resource = RemoteResource::from_pretrained(MBartModelResources::MBART50_MANY_TO_MANY);
let config_resource =
RemoteResource::from_pretrained(MBartConfigResources::MBART50_MANY_TO_MANY);
let vocab_resource = RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
let merges_resource =
RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY;
let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY;

View File

@ -14,19 +14,15 @@ extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
use tch::Device;
fn main() -> anyhow::Result<()> {
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_BASE);
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_BASE);
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
let source_languages = [
Language::English,

View File

@ -24,19 +24,19 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::AlbertTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;

View File

@ -17,10 +17,6 @@ use crate::bart::encoder::BartEncoder;
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::kind::get_negative_infinity;
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
@ -1028,43 +1024,10 @@ impl BartGenerator {
/// # }
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART))
} else {
generate_config.merges_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@ -1293,7 +1256,7 @@ mod test {
use tch::Device;
use crate::{
resources::{RemoteResource, Resource},
resources::{RemoteResource, ResourceProvider},
Config,
};
@ -1302,8 +1265,7 @@ mod test {
#[test]
#[ignore] // compilation is enough, no need to run
fn bart_model_send() {
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let config_resource = Box::new(RemoteResource::from_pretrained(BartConfigResources::BART));
let config_path = config_resource.get_local_path().expect("");
// Set-up masked LM model

View File

@ -19,22 +19,22 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::RobertaTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! };
//! let merges_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;

View File

@ -1215,7 +1215,7 @@ mod test {
use tch::Device;
use crate::{
resources::{RemoteResource, Resource},
resources::{RemoteResource, ResourceProvider},
Config,
};
@ -1224,8 +1224,7 @@ mod test {
#[test]
#[ignore] // compilation is enough, no need to run
fn bert_model_send() {
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let config_resource = Box::new(RemoteResource::from_pretrained(BertConfigResources::BERT));
let config_path = config_resource.get_local_path().expect("");
// Set-up masked LM model

View File

@ -24,19 +24,19 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bert::{BertConfig, BertForMaskedLM};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;

View File

@ -4,8 +4,9 @@ use thiserror::Error;
#[derive(Error, Debug)]
pub enum RustBertError {
#[cfg(feature = "remote")]
#[error("Endpoint not available error: {0}")]
FileDownloadError(String),
FileDownloadError(#[from] cached_path::Error),
#[error("IO error: {0}")]
IOError(String),
@ -23,12 +24,6 @@ pub enum RustBertError {
ValueError(String),
}
impl From<cached_path::Error> for RustBertError {
fn from(error: cached_path::Error) -> Self {
RustBertError::FileDownloadError(error.to_string())
}
}
impl From<std::io::Error> for RustBertError {
fn from(error: std::io::Error) -> Self {
RustBertError::IOError(error.to_string())

View File

@ -1,197 +0,0 @@
//! # Resource definitions for model weights, vocabularies and configuration files
//!
//! This crate relies on the concept of Resources to access the files used by the models.
//! This includes:
//! - model weights
//! - configuration files
//! - vocabularies
//! - (optional) merges files for BPE-based tokenizers
//!
//! These are expected in the pipelines configurations or are used as utilities to reference to the
//! resource location. Two types of resources exist:
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL and a local cached file
//!
//! For both types of resources, the local location of the file can be retrieved using
//! `get_local_path`, allowing the resource file location to be referenced regardless of
//! whether it is a remote or local resource. Default implementations for a number of `RemoteResources` are available as
//! pre-trained models in each model module.
use crate::common::error::RustBertError;
use cached_path::{Cache, Options, ProgressBar};
use lazy_static::lazy_static;
use std::env;
use std::path::PathBuf;
extern crate dirs;
/// # Resource Enum pointing to model, configuration or vocabulary resources
/// Can be of type:
/// - LocalResource
/// - RemoteResource
#[derive(PartialEq, Clone)]
pub enum Resource {
/// Resource already present on the local filesystem.
Local(LocalResource),
/// Resource fetched from a URL and cached locally on first access.
Remote(RemoteResource),
}
impl Resource {
/// Gets the local path for a given resource.
///
/// If the resource is a remote resource, it is downloaded and cached. Then the path
/// to the local cache is returned.
///
/// # Returns
///
/// * `PathBuf` pointing to the resource file
///
/// # Errors
///
/// Returns a `RustBertError` if a remote resource cannot be downloaded or cached.
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{LocalResource, Resource};
/// use std::path::PathBuf;
/// let config_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// });
/// let config_path = config_resource.get_local_path();
/// ```
pub fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
match self {
// Local resources are already on disk: return a copy of the stored path.
Resource::Local(resource) => Ok(resource.local_path.clone()),
// Remote resources go through the global cache, saved under the
// resource's `cache_subdir`; the cached file's path is returned.
Resource::Remote(resource) => {
let cached_path = CACHE.cached_path_with_options(
&resource.url,
&Options::default().subdir(&resource.cache_subdir),
)?;
Ok(cached_path)
}
}
}
}
/// # Local resource
///
/// Points to a file already present on the local filesystem; no download
/// or caching is involved.
#[derive(PartialEq, Clone)]
pub struct LocalResource {
/// Local path for the resource
pub local_path: PathBuf,
}
/// # Remote resource
///
/// Declares a URL to fetch plus the cache subdirectory the downloaded file
/// is stored under. The actual download happens on first `get_local_path`.
#[derive(PartialEq, Clone)]
pub struct RemoteResource {
/// Remote path/url for the resource
pub url: String,
/// Local subdirectory of the cache root where this resource is saved
pub cache_subdir: String,
}
impl RemoteResource {
/// Creates a new RemoteResource from an URL and a custom local path. Note that this does not
/// download the resource (only declares the remote and local locations)
///
/// # Arguments
///
/// * `url` - `&str` Location of the remote resource
/// * `cache_subdir` - `&str` Local subdirectory of the cache root to save the resource to
///
/// # Returns
///
/// * `RemoteResource` RemoteResource object
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{RemoteResource, Resource};
/// let config_resource = Resource::Remote(RemoteResource::new(
/// "http://config_json_location",
/// "configs",
/// ));
/// ```
pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
RemoteResource {
url: url.to_string(),
cache_subdir: cache_subdir.to_string(),
}
}
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
/// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares
/// the remote and local locations)
///
/// # Arguments
///
/// * `name_url_tuple` - `(&str, &str)` Location of the name of model and remote resource
///
/// # Returns
///
/// * `RemoteResource` RemoteResource object
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2",
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
/// )));
/// ```
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
// The model name (first element) doubles as the cache subdirectory.
let cache_subdir = name_url_tuple.0.to_string();
let url = name_url_tuple.1.to_string();
RemoteResource { url, cache_subdir }
}
}
lazy_static! {
#[derive(Copy, Clone, Debug)]
/// # Global cache directory
/// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
/// location. Otherwise defaults to `~/.cache/.rustbert`.
// Built lazily on first access; shows a light progress bar during downloads
// and panics if the cache builder fails.
pub static ref CACHE: Cache = Cache::builder()
.dir(_get_cache_directory())
.progress_bar(Some(ProgressBar::Light))
.build().unwrap();
}
/// Resolves the root directory for cached model files: the `RUSTBERT_CACHE`
/// environment variable if set, otherwise `~/.cache/.rustbert`.
fn _get_cache_directory() -> PathBuf {
match env::var("RUSTBERT_CACHE") {
Ok(value) => PathBuf::from(value),
Err(_) => {
// Panics if the user's home directory cannot be determined.
let mut home = dirs::home_dir().unwrap();
home.push(".cache");
home.push(".rustbert");
home
}
}
}
#[deprecated(
since = "0.9.1",
note = "Please use `Resource.get_local_path()` instead"
)]
/// # (Download) the resource and return a path to its local path
/// This function will download remote resource to their local path if they do not exist yet.
/// Then for both `LocalResource` and `RemoteResource`, it will return the local path to the resource.
/// For `LocalResource` only the resource path is returned.
///
/// # Arguments
///
/// * `resource` - Pointer to the `&Resource` to optionally download and get the local path.
///
/// # Returns
///
/// * `PathBuf` Local path for the resource
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2/model.ot",
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
/// )));
/// let local_path = model_resource.get_local_path();
/// ```
// Deprecated thin wrapper kept for backward compatibility: delegates directly
// to `Resource::get_local_path`.
pub fn download_resource(resource: &Resource) -> Result<PathBuf, RustBertError> {
resource.get_local_path()
}

View File

@ -0,0 +1,32 @@
use crate::common::error::RustBertError;
use crate::resources::ResourceProvider;
use std::path::PathBuf;
/// # Local resource
///
/// Points to a file that is already present on the local filesystem;
/// `get_local_path` performs no I/O for this resource type.
#[derive(PartialEq, Clone)]
pub struct LocalResource {
/// Local path for the resource
pub local_path: PathBuf,
}
impl ResourceProvider for LocalResource {
    /// Returns the path of this local resource.
    ///
    /// No I/O is performed: a copy of the stored `local_path` is returned and
    /// the call always succeeds.
    ///
    /// # Returns
    ///
    /// * `PathBuf` pointing to the resource file
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::{LocalResource, ResourceProvider};
    /// use std::path::PathBuf;
    /// let config_resource = LocalResource {
    ///     local_path: PathBuf::from("path/to/config.json"),
    /// };
    /// let config_path = config_resource.get_local_path();
    /// ```
    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
        let path = self.local_path.to_owned();
        Ok(path)
    }
}

View File

@ -0,0 +1,50 @@
//! # Resource definitions for model weights, vocabularies and configuration files
//!
//! This crate relies on the concept of Resources to access the files used by the models.
//! This includes:
//! - model weights
//! - configuration files
//! - vocabularies
//! - (optional) merges files for BPE-based tokenizers
//!
//! These are expected in the pipelines configurations or are used as utilities to reference to the
//! resource location. Two types of resources are pre-defined:
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL
//!
//! For both types of resources, the local location of the file can be retrieved using
//! `get_local_path`, allowing to reference the resource file location regardless if it is a remote
//! or local resource. Default implementations for a number of `RemoteResources` are available as
//! pre-trained models in each model module.
mod local;
use crate::common::error::RustBertError;
pub use local::LocalResource;
use std::path::PathBuf;
/// # Resource Trait that can provide the location of the model, configuration or vocabulary resources
pub trait ResourceProvider {
/// Provides the local path for a resource.
///
/// # Returns
///
/// * `PathBuf` pointing to the resource file
///
/// # Errors
///
/// Returns a `RustBertError` if the resource cannot be made available as a
/// local file (e.g. a failed download for remote implementations).
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_path = config_resource.get_local_path();
/// ```
fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
}
#[cfg(feature = "remote")]
mod remote;
#[cfg(feature = "remote")]
pub use remote::RemoteResource;

View File

@ -0,0 +1,122 @@
use super::*;
use crate::common::error::RustBertError;
use cached_path::{Cache, Options, ProgressBar};
use dirs::cache_dir;
use lazy_static::lazy_static;
use std::path::PathBuf;
/// # Remote resource that will be downloaded and cached locally on demand
#[derive(PartialEq, Clone)]
pub struct RemoteResource {
    /// Remote path/url for the resource
    pub url: String,
    /// Local subdirectory of the cache root where this resource is saved
    pub cache_subdir: String,
}

impl RemoteResource {
    /// Creates a new RemoteResource from an URL and a custom local path. Note that this does not
    /// download the resource (only declares the remote and local locations)
    ///
    /// # Arguments
    ///
    /// * `url` - `&str` Location of the remote resource
    /// * `cache_subdir` - `&str` Local subdirectory of the cache root to save the resource to
    ///
    /// # Returns
    ///
    /// * `RemoteResource` RemoteResource object
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::RemoteResource;
    /// // The URL comes first, then the cache subdirectory.
    /// let config_resource = RemoteResource::new(
    ///     "http://config_json_location",
    ///     "configs",
    /// );
    /// ```
    pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
        RemoteResource {
            url: url.to_owned(),
            cache_subdir: cache_subdir.to_owned(),
        }
    }

    /// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
    /// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares
    /// the remote and local locations)
    ///
    /// # Arguments
    ///
    /// * `name_url_tuple` - `(&str, &str)` Location of the name of model and remote resource
    ///
    /// # Returns
    ///
    /// * `RemoteResource` RemoteResource object
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::RemoteResource;
    /// let model_resource = RemoteResource::from_pretrained((
    ///     "distilbert-sst2",
    ///     "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
    /// ));
    /// ```
    pub fn from_pretrained((name, url): (&str, &str)) -> RemoteResource {
        // The model name (first tuple element) doubles as the cache subdirectory.
        RemoteResource {
            url: url.to_owned(),
            cache_subdir: name.to_owned(),
        }
    }
}
impl ResourceProvider for RemoteResource {
    /// Gets the local path for a remote resource.
    ///
    /// The remote resource is downloaded into the global cache (under this
    /// resource's `cache_subdir`) if not already present, and the path to the
    /// cached copy is returned.
    ///
    /// # Returns
    ///
    /// * `PathBuf` pointing to the resource file
    ///
    /// # Errors
    ///
    /// Returns a `RustBertError` if downloading or caching the resource fails.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::{RemoteResource, ResourceProvider};
    /// let model_resource = RemoteResource::from_pretrained((
    ///     "distilbert-sst2",
    ///     "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
    /// ));
    /// let model_path = model_resource.get_local_path();
    /// ```
    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
        // Download (if needed) into the configured cache subdirectory; `?`
        // converts `cached_path::Error` into `RustBertError::FileDownloadError`.
        let cached_path = CACHE
            .cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
        Ok(cached_path)
    }
}
lazy_static! {
#[derive(Copy, Clone, Debug)]
/// # Global cache directory
/// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
/// location. Otherwise defaults to `$XDG_CACHE_HOME/.rustbert`, or corresponding user cache for
/// the current system.
// Built lazily on first access; shows a light progress bar during downloads
// and panics if the cache builder fails.
pub static ref CACHE: Cache = Cache::builder()
.dir(_get_cache_directory())
.progress_bar(Some(ProgressBar::Light))
.build().unwrap();
}
/// Resolves the root directory used for cached downloads.
///
/// The `RUSTBERT_CACHE` environment variable takes precedence; otherwise the
/// platform user cache directory (via `dirs::cache_dir`) is used, with a
/// `.rustbert` subdirectory appended.
fn _get_cache_directory() -> PathBuf {
    if let Ok(custom_dir) = std::env::var("RUSTBERT_CACHE") {
        PathBuf::from(custom_dir)
    } else {
        // Panics if the platform cache directory cannot be determined.
        let mut default_dir = cache_dir().unwrap();
        default_dir.push(".rustbert");
        default_dir
    }
}

View File

@ -23,22 +23,22 @@
//! DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
//! DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
//! };
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::resources::{RemoteResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::DeBERTaTokenizer;
//!
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let config_resource = RemoteResource::from_pretrained(
//! DebertaConfigResources::DEBERTA_BASE_MNLI,
//! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
//! );
//! let vocab_resource = RemoteResource::from_pretrained(
//! DebertaVocabResources::DEBERTA_BASE_MNLI,
//! ));
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
//! );
//! let merges_resource = RemoteResource::from_pretrained(
//! DebertaMergesResources::DEBERTA_BASE_MNLI,
//! ));
//! let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
//! );
//! let weights_resource = RemoteResource::from_pretrained(
//! DebertaModelResources::DEBERTA_BASE_MNLI,
//! ));
//! );
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;

View File

@ -25,19 +25,19 @@
//! DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM,
//! DistilBertModelResources, DistilBertVocabResources,
//! };
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;

View File

@ -27,19 +27,19 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;

View File

@ -1029,7 +1029,7 @@ mod test {
use tch::Device;
use crate::{
resources::{RemoteResource, Resource},
resources::{RemoteResource, ResourceProvider},
Config,
};
@ -1038,8 +1038,7 @@ mod test {
#[test]
#[ignore] // compilation is enough, no need to run
fn fnet_model_send() {
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let config_path = config_resource.get_local_path().expect("");
// Set-up masked LM model

View File

@ -22,19 +22,19 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::{BertTokenizer, FNetTokenizer};
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;

View File

@ -19,22 +19,22 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::Gpt2Tokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! };
//! let merges_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;

View File

@ -22,20 +22,20 @@
//! };
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::resources::RemoteResource;
//! use tch::Device;
//!
//! fn main() -> anyhow::Result<()> {
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let config_resource = Box::new(RemoteResource::from_pretrained(
//! GptNeoConfigResources::GPT_NEO_1_3B,
//! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let vocab_resource = Box::new(RemoteResource::from_pretrained(
//! GptNeoVocabResources::GPT_NEO_1_3B,
//! ));
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let merges_resource = Box::new(RemoteResource::from_pretrained(
//! GptNeoMergesResources::GPT_NEO_1_3B,
//! ));
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let model_resource = Box::new(RemoteResource::from_pretrained(
//! GptNeoModelResources::GPT_NEO_1_3B,
//! ));
//!

View File

@ -27,24 +27,24 @@
//! use rust_bert::pipelines::question_answering::{
//! QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
//! };
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::resources::{RemoteResource};
//!
//! fn main() -> anyhow::Result<()> {
//! // Set-up Question Answering model
//! let config = QuestionAnsweringConfig::new(
//! ModelType::Longformer,
//! Resource::Remote(RemoteResource::from_pretrained(
//! RemoteResource::from_pretrained(
//! LongformerModelResources::LONGFORMER_BASE_SQUAD1,
//! )),
//! Resource::Remote(RemoteResource::from_pretrained(
//! ),
//! RemoteResource::from_pretrained(
//! LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
//! )),
//! Resource::Remote(RemoteResource::from_pretrained(
//! ),
//! RemoteResource::from_pretrained(
//! LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
//! )),
//! Some(Resource::Remote(RemoteResource::from_pretrained(
//! ),
//! Some(RemoteResource::from_pretrained(
//! LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
//! ))),
//! )),
//! false,
//! None,
//! false,

View File

@ -10,9 +10,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::m2m_100::decoder::M2M100Decoder;
use crate::m2m_100::encoder::M2M100Encoder;
use crate::m2m_100::LayerState;
@ -25,7 +22,6 @@ use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::pipelines::translation::Language;
use crate::resources::{RemoteResource, Resource};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{M2M100Vocab, Vocab};
@ -618,51 +614,10 @@ impl M2M100Generator {
/// # }
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
))
} else {
generate_config.merges_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@ -889,7 +844,7 @@ mod test {
use tch::Device;
use crate::{
resources::{RemoteResource, Resource},
resources::{RemoteResource, ResourceProvider},
Config,
};
@ -898,7 +853,7 @@ mod test {
#[test]
#[ignore] // compilation is enough, no need to run
fn mbart_model_send() {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
));
let config_path = config_resource.get_local_path().expect("");

View File

@ -20,22 +20,22 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::m2m_100::{M2M100Config, M2M100Model};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::M2M100Tokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! };
//! let merges_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;

View File

@ -21,22 +21,22 @@
//! # use std::path::PathBuf;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::marian::MarianForConditionalGeneration;
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::MarianTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.json"),
//! });
//! let sentence_piece_resource = Resource::Local(LocalResource {
//! };
//! let sentence_piece_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let spiece_path = sentence_piece_resource.get_local_path()?;

View File

@ -12,7 +12,6 @@
use crate::bart::BartModelOutput;
use crate::common::dropout::Dropout;
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::mbart::decoder::MBartDecoder;
use crate::mbart::encoder::MBartEncoder;
use crate::mbart::LayerState;
@ -24,7 +23,6 @@ use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::pipelines::translation::Language;
use crate::resources::{RemoteResource, Resource};
use crate::{Activation, Config, RustBertError};
use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{MBart50Vocab, Vocab};
@ -839,40 +837,9 @@ impl MBartGenerator {
/// # }
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<MBartGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@ -1099,7 +1066,7 @@ mod test {
use tch::Device;
use crate::{
resources::{RemoteResource, Resource},
resources::{RemoteResource, ResourceProvider},
Config,
};
@ -1108,7 +1075,7 @@ mod test {
#[test]
#[ignore] // compilation is enough, no need to run
fn mbart_model_send() {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
));
let config_path = config_resource.get_local_path().expect("");

View File

@ -19,19 +19,19 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::mbart::{MBartConfig, MBartModel};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::MBart50Tokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;

View File

@ -24,19 +24,19 @@
//! MobileBertConfig, MobileBertConfigResources, MobileBertForMaskedLM,
//! MobileBertModelResources, MobileBertVocabResources,
//! };
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::resources::{RemoteResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer;
//!
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let config_resource = RemoteResource::from_pretrained(
//! MobileBertConfigResources::MOBILEBERT_UNCASED,
//! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
//! );
//! let vocab_resource = RemoteResource::from_pretrained(
//! MobileBertVocabResources::MOBILEBERT_UNCASED,
//! ));
//! let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
//! );
//! let weights_resource = RemoteResource::from_pretrained(
//! MobileBertModelResources::MOBILEBERT_UNCASED,
//! ));
//! );
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;

View File

@ -18,22 +18,22 @@
//! # use std::path::PathBuf;
//! use rust_bert::gpt2::Gpt2Config;
//! use rust_bert::openai_gpt::OpenAiGptModel;
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::OpenAiGptTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! };
//! let merges_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;

View File

@ -15,10 +15,7 @@
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::gpt2::Gpt2Config;
use crate::openai_gpt::transformer::Block;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
@ -471,51 +468,10 @@ impl OpenAIGenerator {
pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> {
generate_config.validate();
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
))
} else {
generate_config.merges_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
let mut var_store = nn::VarStore::new(device);

View File

@ -19,19 +19,19 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::pegasus::{PegasusConfig, PegasusModel};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::PegasusTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;

View File

@ -12,8 +12,6 @@
use crate::bart::BartModelOutput;
use crate::common::kind::get_negative_infinity;
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::mbart::MBartConfig;
use crate::pegasus::decoder::PegasusDecoder;
use crate::pegasus::encoder::PegasusEncoder;
@ -601,40 +599,9 @@ impl PegasusConditionalGenerator {
pub fn new(
generate_config: GenerateConfig,
) -> Result<PegasusConditionalGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
PegasusModelResources::CNN_DAILYMAIL,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
PegasusConfigResources::CNN_DAILYMAIL,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
PegasusVocabResources::CNN_DAILYMAIL,
))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();

View File

@ -45,17 +45,21 @@
//! The authors of this repository are not responsible for any generation
//! from the 3rd party utilization of the pretrained system.
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
GPT2Generator, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::gpt2::GPT2Generator;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::resources::ResourceProvider;
use std::collections::HashMap;
use tch::{Device, Kind, Tensor};
use uuid::Uuid;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
resources::RemoteResource,
};
/// # Configuration for multi-turn classification
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
/// different set of default parameters and sets the device to place the model on.
@ -63,13 +67,13 @@ pub struct ConversationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: DialoGPT-medium)
pub model_resource: Resource,
pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: DialoGPT-medium)
pub config_resource: Resource,
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: DialoGPT-medium)
pub vocab_resource: Resource,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: DialoGPT-medium)
pub merges_resource: Resource,
pub merges_resource: Box<dyn ResourceProvider + Send>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
@ -104,20 +108,21 @@ pub struct ConversationConfig {
pub device: Device,
}
#[cfg(feature = "remote")]
impl Default for ConversationConfig {
fn default() -> ConversationConfig {
ConversationConfig {
model_type: ModelType::GPT2,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
model_resource: Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::DIALOGPT_MEDIUM,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
config_resource: Box::new(RemoteResource::from_pretrained(
Gpt2ConfigResources::DIALOGPT_MEDIUM,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
vocab_resource: Box::new(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
merges_resource: Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM,
)),
min_length: 0,

View File

@ -73,10 +73,7 @@ use tch::{no_grad, Device, Tensor};
use crate::bart::LayerState as BartLayerState;
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::common::resources::ResourceProvider;
use crate::gpt_neo::LayerState as GPTNeoLayerState;
use crate::pipelines::generation_utils::private_generation_utils::{
InternalGenerateOptions, PrivateLanguageGenerator,
@ -89,18 +86,24 @@ use crate::xlnet::LayerState as XLNetLayerState;
use self::ordered_float::OrderedFloat;
use crate::pipelines::common::TokenizerOption;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
resources::RemoteResource,
};
extern crate ordered_float;
/// # Configuration for text generation
pub struct GenerateConfig {
/// Model weights resource (default: pretrained GPT2 model)
pub model_resource: Resource,
pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained GPT2 model)
pub config_resource: Resource,
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained GPT2 model)
pub vocab_resource: Resource,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained GPT2 model)
pub merges_resource: Resource,
pub merges_resource: Box<dyn ResourceProvider + Send>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
@ -133,21 +136,14 @@ pub struct GenerateConfig {
pub device: Device,
}
#[cfg(feature = "remote")]
impl Default for GenerateConfig {
fn default() -> GenerateConfig {
GenerateConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::GPT2,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::GPT2,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::GPT2,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
)),
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
min_length: 0,
max_length: 20,
do_sample: true,

View File

@ -78,7 +78,7 @@
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::ner::NERModel;
//! use rust_bert::pipelines::token_classification::TokenClassificationConfig;
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::resources::RemoteResource;
//! use rust_bert::roberta::{
//! RobertaConfigResources, RobertaModelResources, RobertaVocabResources,
//! };
@ -87,13 +87,13 @@
//! # fn main() -> anyhow::Result<()> {
//! let ner_config = TokenClassificationConfig {
//! model_type: ModelType::XLMRoberta,
//! model_resource: Resource::Remote(RemoteResource::from_pretrained(
//! model_resource: Box::new(RemoteResource::from_pretrained(
//! RobertaModelResources::XLM_ROBERTA_NER_DE,
//! )),
//! config_resource: Resource::Remote(RemoteResource::from_pretrained(
//! config_resource: Box::new(RemoteResource::from_pretrained(
//! RobertaConfigResources::XLM_ROBERTA_NER_DE,
//! )),
//! vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
//! vocab_resource: Box::new(RemoteResource::from_pretrained(
//! RobertaVocabResources::XLM_ROBERTA_NER_DE,
//! )),
//! lower_case: false,

View File

@ -82,16 +82,20 @@
//! To run the pipeline for another language, change the POSModel configuration from its default (see the NER pipeline for an illustration).
use crate::common::error::RustBertError;
use crate::mobilebert::{
MobileBertConfigResources, MobileBertModelResources, MobileBertVocabResources,
};
use crate::pipelines::common::ModelType;
use crate::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
};
use crate::resources::{RemoteResource, Resource};
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
use serde::{Deserialize, Serialize};
use tch::Device;
#[cfg(feature = "remote")]
use {
crate::{
mobilebert::{
MobileBertConfigResources, MobileBertModelResources, MobileBertVocabResources,
},
pipelines::{common::ModelType, token_classification::LabelAggregationOption},
resources::RemoteResource,
},
tch::Device,
};
#[derive(Debug, Serialize, Deserialize)]
/// # Part of Speech tag
@ -109,19 +113,20 @@ pub struct POSConfig {
token_classification_config: TokenClassificationConfig,
}
#[cfg(feature = "remote")]
impl Default for POSConfig {
/// Provides a Part of speech tagging model (English)
fn default() -> POSConfig {
POSConfig {
token_classification_config: TokenClassificationConfig {
model_type: ModelType::MobileBert,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
model_resource: Box::new(RemoteResource::from_pretrained(
MobileBertModelResources::MOBILEBERT_ENGLISH_POS,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
config_resource: Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_ENGLISH_POS,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
vocab_resource: Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_ENGLISH_POS,
)),
merges_resource: None,

View File

@ -46,17 +46,14 @@
use crate::albert::AlbertForQuestionAnswering;
use crate::bert::BertForQuestionAnswering;
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::deberta::DebertaForQuestionAnswering;
use crate::distilbert::{
DistilBertConfigResources, DistilBertForQuestionAnswering, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::distilbert::DistilBertForQuestionAnswering;
use crate::fnet::FNetForQuestionAnswering;
use crate::longformer::LongformerForQuestionAnswering;
use crate::mobilebert::MobileBertForQuestionAnswering;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::reformer::ReformerForQuestionAnswering;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForQuestionAnswering;
use crate::xlnet::XLNetForQuestionAnswering;
use rust_tokenizers::{Offset, TokenIdsWithOffsets, TokenizedInput};
@ -70,6 +67,12 @@ use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Tensor};
#[cfg(feature = "remote")]
use crate::{
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
resources::RemoteResource,
};
#[derive(Serialize, Deserialize)]
/// # Input for Question Answering
/// Includes a context (containing the answer) and question strings
@ -124,13 +127,13 @@ fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
/// Contains information regarding the model to load and device to place the model on.
pub struct QuestionAnsweringConfig {
/// Model weights resource (default: pretrained DistilBERT model on SQuAD)
pub model_resource: Resource,
pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained DistilBERT model on SQuAD)
pub config_resource: Resource,
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained DistilBERT model on SQuAD)
pub vocab_resource: Resource,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: None)
pub merges_resource: Option<Resource>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Model type
@ -157,27 +160,30 @@ impl QuestionAnsweringConfig {
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new(
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
) -> QuestionAnsweringConfig {
) -> QuestionAnsweringConfig
where
R: ResourceProvider + Send + 'static,
{
QuestionAnsweringConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case,
strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(),
@ -194,21 +200,21 @@ impl QuestionAnsweringConfig {
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
/// * max_seq_length - Optional maximum sequence token length to limit memory footprint. If the context is too long, it will be processed with sliding windows. Defaults to 384.
/// * max_query_length - Optional maximum question token length. Defaults to 64.
/// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128.
/// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15.
pub fn custom_new(
pub fn custom_new<R>(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
@ -216,13 +222,16 @@ impl QuestionAnsweringConfig {
doc_stride: impl Into<Option<usize>>,
max_query_length: impl Into<Option<usize>>,
max_answer_length: impl Into<Option<usize>>,
) -> QuestionAnsweringConfig {
) -> QuestionAnsweringConfig
where
R: ResourceProvider + Send + 'static,
{
QuestionAnsweringConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case,
strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(),
@ -235,16 +244,17 @@ impl QuestionAnsweringConfig {
}
}
#[cfg(feature = "remote")]
impl Default for QuestionAnsweringConfig {
fn default() -> QuestionAnsweringConfig {
QuestionAnsweringConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(
model_resource: Box::new(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SQUAD,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
config_resource: Box::new(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
vocab_resource: Box::new(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
)),
merges_resource: None,

View File

@ -15,7 +15,7 @@
//!
//! ```no_run
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::resources::{RemoteResource};
//! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources};
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
//! use rust_bert::pipelines::common::ModelType;
@ -23,9 +23,9 @@
//!
//! //Load a configuration
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
//! Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2)),
//! Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2)),
//! Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2)),
//! RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2),
//! RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
//! RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
//! None, //merges resource only relevant with ModelType::Roberta
//! true, //lowercase
//! None, //strip_accents
@ -61,17 +61,14 @@ use crate::albert::AlbertForSequenceClassification;
use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification;
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::deberta::DebertaForSequenceClassification;
use crate::distilbert::{
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::distilbert::DistilBertModelClassifier;
use crate::fnet::FNetForSequenceClassification;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::reformer::ReformerForSequenceClassification;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForSequenceClassification;
use crate::xlnet::XLNetForSequenceClassification;
use rust_tokenizers::tokenizer::TruncationStrategy;
@ -82,6 +79,12 @@ use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Kind, Tensor};
#[cfg(feature = "remote")]
use crate::{
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
resources::RemoteResource,
};
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # Label generated by a `SequenceClassificationModel`
pub struct Label {
@ -102,13 +105,13 @@ pub struct SequenceClassificationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model on CoNLL)
pub model_resource: Resource,
pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BERT model on CoNLL)
pub config_resource: Resource,
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model on CoNLL)
pub vocab_resource: Resource,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: None)
pub merges_resource: Option<Resource>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
pub lower_case: bool,
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
@ -125,27 +128,30 @@ impl SequenceClassificationConfig {
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model - The `Resource` pointing to the model to load (e.g. model.ot)
/// * config - The `Resource' pointing to the model configuration to load (e.g. config.json)
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new(
/// * model - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
) -> SequenceClassificationConfig {
) -> SequenceClassificationConfig
where
R: ResourceProvider + Send + 'static,
{
SequenceClassificationConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case,
strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(),
@ -154,26 +160,20 @@ impl SequenceClassificationConfig {
}
}
#[cfg(feature = "remote")]
impl Default for SequenceClassificationConfig {
/// Provides a defaultSST-2 sentiment analysis model (English)
fn default() -> SequenceClassificationConfig {
SequenceClassificationConfig {
model_type: ModelType::DistilBert,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SST2,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SST2,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SST2,
)),
merges_resource: None,
lower_case: true,
strip_accents: None,
add_prefix_space: None,
device: Device::cuda_if_available(),
}
SequenceClassificationConfig::new(
ModelType::DistilBert,
RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2),
RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
None,
true,
None,
None,
)
}
}

View File

@ -64,17 +64,21 @@
use tch::Device;
use crate::bart::{
BartConfigResources, BartGenerator, BartMergesResources, BartModelResources, BartVocabResources,
};
use crate::bart::BartGenerator;
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::pegasus::PegasusConditionalGenerator;
use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::prophetnet::ProphetNetConditionalGenerator;
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
#[cfg(feature = "remote")]
use crate::{
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
resources::RemoteResource,
};
/// # Configuration for text summarization
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
/// different set of default parameters and sets the device to place the model on.
@ -82,13 +86,13 @@ pub struct SummarizationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BART model on CNN-DM)
pub model_resource: Resource,
pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BART model on CNN-DM)
pub config_resource: Resource,
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BART model on CNN-DM)
pub vocab_resource: Resource,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Resource,
pub merges_resource: Box<dyn ResourceProvider + Send>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
@ -127,45 +131,26 @@ impl SummarizationConfig {
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - The `Resource` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new(
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new<R>(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Resource,
) -> SummarizationConfig {
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: R,
) -> SummarizationConfig
where
R: ResourceProvider + Send + 'static,
{
SummarizationConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
device: Device::cuda_if_available(),
..Default::default()
}
}
}
impl Default for SummarizationConfig {
fn default() -> SummarizationConfig {
SummarizationConfig {
model_type: ModelType::Bart,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
BartModelResources::BART_CNN,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::BART_CNN,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
BartVocabResources::BART_CNN,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
)),
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: Box::new(merges_resource),
min_length: 56,
max_length: 142,
do_sample: false,
@ -185,6 +170,19 @@ impl Default for SummarizationConfig {
}
}
/// Default configuration is only available with the `remote` feature, since it
/// downloads the pretrained BART model fine-tuned on CNN/DailyMail from the
/// remote model hub.
#[cfg(feature = "remote")]
impl Default for SummarizationConfig {
    /// Delegates to [`SummarizationConfig::new`] with `RemoteResource`s for the
    /// BART-CNN weights, configuration, vocabulary and merges files.
    fn default() -> SummarizationConfig {
        SummarizationConfig::new(
            ModelType::Bart,
            RemoteResource::from_pretrained(BartModelResources::BART_CNN),
            RemoteResource::from_pretrained(BartConfigResources::BART_CNN),
            RemoteResource::from_pretrained(BartVocabResources::BART_CNN),
            RemoteResource::from_pretrained(BartMergesResources::BART_CNN),
        )
    }
}
impl From<SummarizationConfig> for GenerateConfig {
fn from(config: SummarizationConfig) -> GenerateConfig {
GenerateConfig {

View File

@ -34,19 +34,22 @@
use tch::Device;
use crate::common::error::RustBertError;
use crate::common::resources::RemoteResource;
use crate::gpt2::{
GPT2Generator, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::gpt2::GPT2Generator;
use crate::gpt_neo::GptNeoGenerator;
use crate::openai_gpt::OpenAIGenerator;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
use crate::reformer::ReformerGenerator;
use crate::resources::Resource;
use crate::resources::ResourceProvider;
use crate::xlnet::XLNetGenerator;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
resources::RemoteResource,
};
/// # Configuration for text generation
/// Contains information regarding the model to load, mirrors the GenerateConfig, with a
/// different set of default parameters and sets the device to place the model on.
@ -54,13 +57,13 @@ pub struct TextGenerationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BART model on CNN-DM)
pub model_resource: Resource,
pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BART model on CNN-DM)
pub config_resource: Resource,
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BART model on CNN-DM)
pub vocab_resource: Resource,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Resource,
pub merges_resource: Box<dyn ResourceProvider + Send>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
@ -99,45 +102,26 @@ impl TextGenerationConfig {
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - The `Resource` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new(
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new<R>(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Resource,
) -> TextGenerationConfig {
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: R,
) -> TextGenerationConfig
where
R: ResourceProvider + Send + 'static,
{
TextGenerationConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
device: Device::cuda_if_available(),
..Default::default()
}
}
}
impl Default for TextGenerationConfig {
fn default() -> TextGenerationConfig {
TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::GPT2_MEDIUM,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::GPT2_MEDIUM,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::GPT2_MEDIUM,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2_MEDIUM,
)),
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: Box::new(merges_resource),
min_length: 0,
max_length: 20,
do_sample: true,
@ -157,6 +141,19 @@ impl Default for TextGenerationConfig {
}
}
/// Default configuration is only available with the `remote` feature, since it
/// downloads the pretrained GPT2-medium model from the remote model hub.
#[cfg(feature = "remote")]
impl Default for TextGenerationConfig {
    /// Delegates to [`TextGenerationConfig::new`] with `RemoteResource`s for the
    /// GPT2-medium weights, configuration, vocabulary and merges files.
    fn default() -> TextGenerationConfig {
        TextGenerationConfig::new(
            ModelType::GPT2,
            RemoteResource::from_pretrained(Gpt2ModelResources::GPT2_MEDIUM),
            RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM),
            RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM),
            RemoteResource::from_pretrained(Gpt2MergesResources::GPT2_MEDIUM),
        )
    }
}
impl From<TextGenerationConfig> for GenerateConfig {
fn from(config: TextGenerationConfig) -> GenerateConfig {
GenerateConfig {

View File

@ -16,17 +16,18 @@
//!
//! ```no_run
//! use rust_bert::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
//! use rust_bert::resources::{Resource,RemoteResource};
//! use rust_bert::resources::RemoteResource;
//! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
//! use rust_bert::pipelines::common::ModelType;
//! # fn main() -> anyhow::Result<()> {
//!
//! //Load a configuration
//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
//! let config = TokenClassificationConfig::new(ModelType::Bert,
//! Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
//! Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
//! Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
//! let config = TokenClassificationConfig::new(
//! ModelType::Bert,
//! RemoteResource::from_pretrained(BertModelResources::BERT_NER),
//! RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
//! RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
//! None, //merges resource only relevant with ModelType::Roberta
//! false, //lowercase
//! None, //strip_accents
@ -111,11 +112,8 @@
//! ```
use crate::albert::AlbertForTokenClassification;
use crate::bert::{
BertConfigResources, BertForTokenClassification, BertModelResources, BertVocabResources,
};
use crate::bert::BertForTokenClassification;
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::deberta::DebertaForTokenClassification;
use crate::distilbert::DistilBertForTokenClassification;
use crate::electra::ElectraForTokenClassification;
@ -123,6 +121,7 @@ use crate::fnet::FNetForTokenClassification;
use crate::longformer::LongformerForTokenClassification;
use crate::mobilebert::MobileBertForTokenClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForTokenClassification;
use crate::xlnet::XLNetForTokenClassification;
use rust_tokenizers::tokenizer::Tokenizer;
@ -137,6 +136,12 @@ use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Kind, Tensor};
#[cfg(feature = "remote")]
use crate::{
bert::{BertConfigResources, BertModelResources, BertVocabResources},
resources::RemoteResource,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Token generated by a `TokenClassificationModel`
pub struct Token {
@ -215,13 +220,13 @@ pub struct TokenClassificationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model on CoNLL)
pub model_resource: Resource,
pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BERT model on CoNLL)
pub config_resource: Resource,
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model on CoNLL)
pub vocab_resource: Resource,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained BERT model on CoNLL)
pub merges_resource: Option<Resource>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
pub lower_case: bool,
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
@ -242,28 +247,31 @@ impl TokenClassificationConfig {
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model - The `Resource` pointing to the model to load (e.g. model.ot)
/// * config - The `Resource' pointing to the model configuration to load (e.g. config.json)
/// * vocab - The `Resource' pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new(
/// * model - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab - The `ResourceProvider` pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `ResourceProvider` pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
label_aggregation_function: LabelAggregationOption,
) -> TokenClassificationConfig {
) -> TokenClassificationConfig
where
R: ResourceProvider + Send + 'static,
{
TokenClassificationConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case,
strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(),
@ -274,28 +282,21 @@ impl TokenClassificationConfig {
}
}
#[cfg(feature = "remote")]
impl Default for TokenClassificationConfig {
/// Provides a default CoNLL-2003 NER model (English)
fn default() -> TokenClassificationConfig {
TokenClassificationConfig {
model_type: ModelType::Bert,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
)),
merges_resource: None,
lower_case: false,
strip_accents: None,
add_prefix_space: None,
device: Device::cuda_if_available(),
label_aggregation_function: LabelAggregationOption::First,
batch_size: 64,
}
TokenClassificationConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_NER),
RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
None,
false,
None,
None,
LabelAggregationOption::First,
)
}
}

View File

@ -21,22 +21,14 @@
//! };
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::resources::RemoteResource;
//! use tch::Device;
//!
//! fn main() -> anyhow::Result<()> {
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained(
//! M2M100ModelResources::M2M100_418M,
//! ));
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
//! M2M100ConfigResources::M2M100_418M,
//! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
//! M2M100VocabResources::M2M100_418M,
//! ));
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
//! M2M100MergesResources::M2M100_418M,
//! ));
//! let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
//! let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
//! let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
//! let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
//!
//! let source_languages = M2M100SourceLanguages::M2M100_418M;
//! let target_languages = M2M100TargetLanguages::M2M100_418M;

View File

@ -1,31 +1,14 @@
use crate::m2m_100::{
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
M2M100TargetLanguages, M2M100VocabResources,
};
use crate::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use crate::mbart::{
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
MBartVocabResources,
};
use crate::pipelines::common::ModelType;
use crate::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use crate::resources::{RemoteResource, Resource};
use crate::RustBertError;
use crate::pipelines::translation::Language;
use std::fmt::Debug;
use tch::Device;
// Pre-refactor bundle of everything needed to build a `TranslationConfig`
// for one pretrained checkpoint. Fields use the legacy `Resource` enum;
// presumably superseded by the generic `TranslationResources<R>` in
// `model_fetchers` — confirm against the final file state.
struct TranslationResources {
    // Architecture of the checkpoint (set to Marian/MBart/M2M100 by the fetchers).
    model_type: ModelType,
    // Model weights resource (e.g. model.ot).
    model_resource: Resource,
    // Model configuration resource (e.g. config.json).
    config_resource: Resource,
    // Tokenizer vocabulary resource.
    vocab_resource: Resource,
    // Merges / SentencePiece model resource.
    merges_resource: Resource,
    // Source languages supported by this checkpoint.
    source_languages: Vec<Language>,
    // Target languages supported by this checkpoint.
    target_languages: Vec<Language>,
}
#[cfg(feature = "remote")]
use crate::{
pipelines::translation::{TranslationConfig, TranslationModel},
resources::ResourceProvider,
RustBertError,
};
#[derive(Clone, Copy, PartialEq)]
enum ModelSize {
@ -86,21 +69,6 @@ pub struct TranslationModelBuilder {
model_size: Option<ModelSize>,
}
// Expands to the resource/language tuple for a named Marian checkpoint:
// `((model, config, vocab, spm), source_languages, target_languages)`.
// Using a macro avoids repeating the four parallel `Marian*Resources::$name`
// lookups for every supported language pair.
macro_rules! get_marian_resources {
    ($name:ident) => {
        (
            (
                MarianModelResources::$name,
                MarianConfigResources::$name,
                MarianVocabResources::$name,
                MarianSpmResources::$name,
            ),
            // The language constants are slices; collect them into owned Vecs.
            MarianSourceLanguages::$name.iter().cloned().collect(),
            MarianTargetLanguages::$name.iter().cloned().collect(),
        )
    };
}
impl Default for TranslationModelBuilder {
fn default() -> Self {
TranslationModelBuilder::new()
@ -335,29 +303,162 @@ impl TranslationModelBuilder {
self
}
fn get_default_model(
&self,
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> {
Ok(
match self.get_marian_model(source_languages, target_languages) {
Ok(marian_resources) => marian_resources,
Err(_) => match self.model_size {
/// Creates the translation model based on the specifications provided
///
/// # Returns
/// * `TranslationModel` Generated translation model
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::translation::Language;
/// use rust_bert::pipelines::translation::TranslationModelBuilder;
/// fn main() -> anyhow::Result<()> {
/// let model = TranslationModelBuilder::new()
/// .with_target_languages([
/// Language::Japanese,
/// Language::Korean,
/// Language::ChineseMandarin,
/// ])
/// .create_model();
/// Ok(())
/// }
/// ```
#[cfg(feature = "remote")]
pub fn create_model(&self) -> Result<TranslationModel, RustBertError> {
let device = self.device.unwrap_or_else(Device::cuda_if_available);
let translation_resources = match (
&self.model_type,
&self.source_languages,
&self.target_languages,
) {
(Some(ModelType::M2M100), source_languages, target_languages) => {
match self.model_size {
Some(value) if value == ModelSize::XLarge => {
self.get_m2m100_xlarge_resources(source_languages, target_languages)?
model_fetchers::get_m2m100_xlarge_resources(
source_languages.as_ref(),
target_languages.as_ref(),
)?
}
_ => self.get_m2m100_large_resources(source_languages, target_languages)?,
},
},
)
_ => model_fetchers::get_m2m100_large_resources(
source_languages.as_ref(),
target_languages.as_ref(),
)?,
}
}
(Some(ModelType::MBart), source_languages, target_languages) => {
model_fetchers::get_mbart50_resources(
source_languages.as_ref(),
target_languages.as_ref(),
)?
}
(Some(ModelType::Marian), source_languages, target_languages) => {
model_fetchers::get_marian_model(
source_languages.as_ref(),
target_languages.as_ref(),
)?
}
(None, source_languages, target_languages) => model_fetchers::get_default_model(
&self.model_size,
source_languages.as_ref(),
target_languages.as_ref(),
)?,
(_, None, None) | (_, _, None) | (_, None, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Source and target languages must be specified for {:?}",
self.model_type.unwrap()
)));
}
(Some(model_type), _, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Automated translation model builder not implemented for {:?}",
model_type
)));
}
};
let translation_config = TranslationConfig::new(
translation_resources.model_type,
translation_resources.model_resource,
translation_resources.config_resource,
translation_resources.vocab_resource,
translation_resources.merges_resource,
translation_resources.source_languages,
translation_resources.target_languages,
device,
);
TranslationModel::new(translation_config)
}
}
#[cfg(feature = "remote")]
mod model_fetchers {
use super::*;
use crate::{
m2m_100::{
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources,
M2M100SourceLanguages, M2M100TargetLanguages, M2M100VocabResources,
},
marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
},
mbart::{
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
MBartVocabResources,
},
resources::RemoteResource,
};
pub(super) struct TranslationResources<R>
where
R: ResourceProvider + Send + 'static,
{
pub(super) model_type: ModelType,
pub(super) model_resource: R,
pub(super) config_resource: R,
pub(super) vocab_resource: R,
pub(super) merges_resource: R,
pub(super) source_languages: Vec<Language>,
pub(super) target_languages: Vec<Language>,
}
fn get_marian_model(
&self,
// Expands to the resource/language tuple for a named Marian checkpoint:
// `((model, config, vocab, spm), source_languages, target_languages)`.
// Using a macro avoids repeating the four parallel `Marian*Resources::$name`
// lookups for every supported language pair.
macro_rules! get_marian_resources {
    ($name:ident) => {
        (
            (
                MarianModelResources::$name,
                MarianConfigResources::$name,
                MarianVocabResources::$name,
                MarianSpmResources::$name,
            ),
            // The language constants are slices; collect them into owned Vecs.
            MarianSourceLanguages::$name.iter().cloned().collect(),
            MarianTargetLanguages::$name.iter().cloned().collect(),
        )
    };
}
pub(super) fn get_default_model(
model_size: &Option<ModelSize>,
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> {
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
Ok(match get_marian_model(source_languages, target_languages) {
Ok(marian_resources) => marian_resources,
Err(_) => match model_size {
Some(value) if value == &ModelSize::XLarge => {
get_m2m100_xlarge_resources(source_languages, target_languages)?
}
_ => get_m2m100_large_resources(source_languages, target_languages)?,
},
})
}
pub(super) fn get_marian_model(
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
let (resources, source_languages, target_languages) =
if let (Some(source_languages), Some(target_languages)) =
(source_languages, target_languages)
@ -446,20 +547,19 @@ impl TranslationModelBuilder {
Ok(TranslationResources {
model_type: ModelType::Marian,
model_resource: Resource::Remote(RemoteResource::from_pretrained(resources.0)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(resources.1)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(resources.2)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(resources.3)),
model_resource: RemoteResource::from_pretrained(resources.0),
config_resource: RemoteResource::from_pretrained(resources.1),
vocab_resource: RemoteResource::from_pretrained(resources.2),
merges_resource: RemoteResource::from_pretrained(resources.3),
source_languages,
target_languages,
})
}
fn get_mbart50_resources(
&self,
pub(super) fn get_mbart50_resources(
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> {
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
if let Some(source_languages) = source_languages {
if !source_languages
.iter()
@ -488,28 +588,27 @@ impl TranslationModelBuilder {
Ok(TranslationResources {
model_type: ModelType::MBart,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
model_resource: RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
),
config_resource: RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
),
vocab_resource: RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
),
merges_resource: RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
),
source_languages: MBartSourceLanguages::MBART50_MANY_TO_MANY.to_vec(),
target_languages: MBartTargetLanguages::MBART50_MANY_TO_MANY.to_vec(),
})
}
fn get_m2m100_large_resources(
&self,
pub(super) fn get_m2m100_large_resources(
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> {
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
if let Some(source_languages) = source_languages {
if !source_languages
.iter()
@ -538,28 +637,19 @@ impl TranslationModelBuilder {
Ok(TranslationResources {
model_type: ModelType::M2M100,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
)),
model_resource: RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M),
config_resource: RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M),
vocab_resource: RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M),
merges_resource: RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M),
source_languages: M2M100SourceLanguages::M2M100_418M.to_vec(),
target_languages: M2M100TargetLanguages::M2M100_418M.to_vec(),
})
}
fn get_m2m100_xlarge_resources(
&self,
pub(super) fn get_m2m100_xlarge_resources(
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> {
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
if let Some(source_languages) = source_languages {
if !source_languages
.iter()
@ -588,97 +678,12 @@ impl TranslationModelBuilder {
Ok(TranslationResources {
model_type: ModelType::M2M100,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_1_2B,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_1_2B,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_1_2B,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_1_2B,
)),
model_resource: RemoteResource::from_pretrained(M2M100ModelResources::M2M100_1_2B),
config_resource: RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_1_2B),
vocab_resource: RemoteResource::from_pretrained(M2M100VocabResources::M2M100_1_2B),
merges_resource: RemoteResource::from_pretrained(M2M100MergesResources::M2M100_1_2B),
source_languages: M2M100SourceLanguages::M2M100_1_2B.to_vec(),
target_languages: M2M100TargetLanguages::M2M100_1_2B.to_vec(),
})
}
/// Creates the translation model based on the specifications provided
///
/// # Returns
/// * `TranslationModel` Generated translation model
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::translation::Language;
/// use rust_bert::pipelines::translation::TranslationModelBuilder;
/// fn main() -> anyhow::Result<()> {
/// let model = TranslationModelBuilder::new()
/// .with_target_languages([
/// Language::Japanese,
/// Language::Korean,
/// Language::ChineseMandarin,
/// ])
/// .create_model();
/// Ok(())
/// }
/// ```
pub fn create_model(&self) -> Result<TranslationModel, RustBertError> {
let device = self.device.unwrap_or_else(Device::cuda_if_available);
let translation_resources = match (
&self.model_type,
&self.source_languages,
&self.target_languages,
) {
(Some(ModelType::M2M100), source_languages, target_languages) => {
match self.model_size {
Some(value) if value == ModelSize::XLarge => self.get_m2m100_xlarge_resources(
source_languages.as_ref(),
target_languages.as_ref(),
)?,
_ => self.get_m2m100_large_resources(
source_languages.as_ref(),
target_languages.as_ref(),
)?,
}
}
(Some(ModelType::MBart), source_languages, target_languages) => {
self.get_mbart50_resources(source_languages.as_ref(), target_languages.as_ref())?
}
(Some(ModelType::Marian), source_languages, target_languages) => {
self.get_marian_model(source_languages.as_ref(), target_languages.as_ref())?
}
(None, source_languages, target_languages) => {
self.get_default_model(source_languages.as_ref(), target_languages.as_ref())?
}
(_, None, None) | (_, _, None) | (_, None, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Source and target languages must be specified for {:?}",
self.model_type.unwrap()
)));
}
(Some(model_type), _, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Automated translation model builder not implemented for {:?}",
model_type
)));
}
};
let translation_config = TranslationConfig::new(
translation_resources.model_type,
translation_resources.model_resource,
translation_resources.config_resource,
translation_resources.vocab_resource,
translation_resources.merges_resource,
translation_resources.source_languages,
translation_resources.target_languages,
device,
);
TranslationModel::new(translation_config)
}
}

View File

@ -14,13 +14,13 @@
use tch::Device;
use crate::common::error::RustBertError;
use crate::common::resources::Resource;
use crate::m2m_100::M2M100Generator;
use crate::marian::MarianGenerator;
use crate::mbart::MBartGenerator;
use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
@ -374,13 +374,13 @@ pub struct TranslationConfig {
/// Model type used for translation
pub model_type: ModelType,
/// Model weights resource
pub model_resource: Resource,
pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource
pub config_resource: Resource,
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource
pub vocab_resource: Resource,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource
pub merges_resource: Resource,
pub merges_resource: Box<dyn ResourceProvider + Send>,
/// Supported source languages
pub source_languages: HashSet<Language>,
/// Supported target languages
@ -435,18 +435,18 @@ impl TranslationConfig {
/// };
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::TranslationConfig;
/// use rust_bert::resources::{RemoteResource, Resource};
/// use rust_bert::resources::RemoteResource;
/// use tch::Device;
///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// let model_resource = RemoteResource::from_pretrained(
/// MarianModelResources::ROMANCE2ENGLISH,
/// ));
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
/// );
/// let config_resource = RemoteResource::from_pretrained(
/// MarianConfigResources::ROMANCE2ENGLISH,
/// ));
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
/// );
/// let vocab_resource = RemoteResource::from_pretrained(
/// MarianVocabResources::ROMANCE2ENGLISH,
/// ));
/// );
///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
@ -464,17 +464,18 @@ impl TranslationConfig {
/// # Ok(())
/// # }
/// ```
pub fn new<S, T>(
pub fn new<R, S, T>(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Resource,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: R,
source_languages: S,
target_languages: T,
device: impl Into<Option<Device>>,
) -> TranslationConfig
where
R: ResourceProvider + Send + 'static,
S: AsRef<[Language]>,
T: AsRef<[Language]>,
{
@ -482,10 +483,10 @@ impl TranslationConfig {
TranslationConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: Box::new(merges_resource),
source_languages: source_languages.as_ref().iter().cloned().collect(),
target_languages: target_languages.as_ref().iter().cloned().collect(),
device,
@ -798,18 +799,18 @@ impl TranslationModel {
/// };
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
/// use rust_bert::resources::{RemoteResource, Resource};
/// use rust_bert::resources::RemoteResource;
/// use tch::Device;
///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// let model_resource = RemoteResource::from_pretrained(
/// MarianModelResources::ROMANCE2ENGLISH,
/// ));
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
/// );
/// let config_resource = RemoteResource::from_pretrained(
/// MarianConfigResources::ROMANCE2ENGLISH,
/// ));
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
/// );
/// let vocab_resource = RemoteResource::from_pretrained(
/// MarianVocabResources::ROMANCE2ENGLISH,
/// ));
/// );
///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
@ -859,21 +860,21 @@ impl TranslationModel {
/// };
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
/// use rust_bert::resources::{RemoteResource, Resource};
/// use rust_bert::resources::RemoteResource;
/// use tch::Device;
///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// let model_resource = RemoteResource::from_pretrained(
/// MarianModelResources::ENGLISH2ROMANCE,
/// ));
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
/// );
/// let config_resource = RemoteResource::from_pretrained(
/// MarianConfigResources::ENGLISH2ROMANCE,
/// ));
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
/// );
/// let vocab_resource = RemoteResource::from_pretrained(
/// MarianVocabResources::ENGLISH2ROMANCE,
/// ));
/// let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
/// );
/// let merges_resource = RemoteResource::from_pretrained(
/// MarianSpmResources::ENGLISH2ROMANCE,
/// ));
/// );
/// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
/// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
///
@ -938,15 +939,10 @@ mod test {
#[test]
#[ignore] // no need to run, compilation is enough to verify it is Send
fn test() {
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianModelResources::ROMANCE2ENGLISH,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianConfigResources::ROMANCE2ENGLISH,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianVocabResources::ROMANCE2ENGLISH,
));
let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
let config_resource =
RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;

View File

@ -99,10 +99,7 @@
//! ```
use crate::albert::AlbertForSequenceClassification;
use crate::bart::{
BartConfigResources, BartForSequenceClassification, BartMergesResources, BartModelResources,
BartVocabResources,
};
use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification;
use crate::deberta::DebertaForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier;
@ -110,7 +107,7 @@ use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::sequence_classification::Label;
use crate::resources::{RemoteResource, Resource};
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForSequenceClassification;
use crate::xlnet::XLNetForSequenceClassification;
use crate::RustBertError;
@ -122,19 +119,25 @@ use tch::kind::Kind::{Bool, Float};
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Tensor};
#[cfg(feature = "remote")]
use crate::{
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
resources::RemoteResource,
};
/// # Configuration for ZeroShotClassificationModel
/// Contains information regarding the model to load and device to place the model on.
pub struct ZeroShotClassificationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model on CoNLL)
pub model_resource: Resource,
pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BERT model on CoNLL)
pub config_resource: Resource,
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model on CoNLL)
pub vocab_resource: Resource,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: None)
pub merges_resource: Option<Resource>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
pub lower_case: bool,
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
@ -151,27 +154,30 @@ impl ZeroShotClassificationConfig {
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model - The `Resource` pointing to the model to load (e.g. model.ot)
/// * config - The `Resource' pointing to the model configuration to load (e.g. config.json)
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new(
/// * model - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
) -> ZeroShotClassificationConfig {
) -> ZeroShotClassificationConfig
where
R: ResourceProvider + Send + 'static,
{
ZeroShotClassificationConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case,
strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(),
@ -180,21 +186,22 @@ impl ZeroShotClassificationConfig {
}
}
#[cfg(feature = "remote")]
impl Default for ZeroShotClassificationConfig {
/// Provides a defaultSST-2 sentiment analysis model (English)
fn default() -> ZeroShotClassificationConfig {
ZeroShotClassificationConfig {
model_type: ModelType::Bart,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
model_resource: Box::new(RemoteResource::from_pretrained(
BartModelResources::BART_MNLI,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
config_resource: Box::new(RemoteResource::from_pretrained(
BartConfigResources::BART_MNLI,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
vocab_resource: Box::new(RemoteResource::from_pretrained(
BartVocabResources::BART_MNLI,
)),
merges_resource: Some(Resource::Remote(RemoteResource::from_pretrained(
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
BartMergesResources::BART_MNLI,
))),
lower_case: false,

View File

@ -20,17 +20,17 @@
//! use rust_bert::prophetnet::{
//! ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
//! };
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::resources::RemoteResource;
//! use tch::Device;
//!
//! fn main() -> anyhow::Result<()> {
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let config_resource = Box::new(RemoteResource::from_pretrained(
//! ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
//! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let vocab_resource = Box::new(RemoteResource::from_pretrained(
//! ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
//! ));
//! let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let weights_resource = Box::new(RemoteResource::from_pretrained(
//! ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
//! ));
//!

View File

@ -18,8 +18,6 @@ use rust_tokenizers::vocab::{ProphetNetVocab, Vocab};
use serde::{Deserialize, Serialize};
use tch::{nn, Kind, Tensor};
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
@ -909,40 +907,9 @@ impl ProphetNetConditionalGenerator {
pub fn new(
generate_config: GenerateConfig,
) -> Result<ProphetNetConditionalGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();

View File

@ -19,19 +19,19 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::reformer::{ReformerConfig, ReformerModel};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::ReformerTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/weights.ot"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;

View File

@ -23,8 +23,6 @@ use tch::{nn, Device, Kind, Tensor};
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
@ -1019,40 +1017,9 @@ pub struct ReformerGenerator {
impl ReformerGenerator {
pub fn new(generate_config: GenerateConfig) -> Result<ReformerGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();

View File

@ -23,23 +23,23 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bert::BertConfig;
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::roberta::RobertaForMaskedLM;
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::RobertaTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! };
//! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! };
//! let merges_resource = LocalResource {
//! local_path: PathBuf::from("path/to/merges.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;

View File

@ -19,20 +19,20 @@
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::t5::{T5Config, T5ForConditionalGeneration};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::T5Tokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let sentence_piece_resource = Resource::Local(LocalResource {
//! };
//! let sentence_piece_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! };
//! let config_path = config_resource.get_local_path()?;
//! let spiece_path = sentence_piece_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;

View File

@ -18,8 +18,6 @@ use serde::{Deserialize, Serialize};
use tch::nn::embedding;
use tch::{nn, Tensor};
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
@ -715,34 +713,9 @@ pub struct T5Generator {
impl T5Generator {
pub fn new(generate_config: GenerateConfig) -> Result<T5Generator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();

View File

@ -22,18 +22,18 @@
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::generation_utils::LanguageGenerator;
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::resources::RemoteResource;
//! use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let config_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetConfigResources::XLNET_BASE_CASED,
//! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let vocab_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetVocabResources::XLNET_BASE_CASED,
//! ));
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let merges_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetVocabResources::XLNET_BASE_CASED,
//! ));
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained(
//! let model_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetModelResources::XLNET_BASE_CASED,
//! ));
//! let generate_config = TextGenerationConfig {

View File

@ -6,7 +6,7 @@ use rust_bert::albert::{
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertModelResources, AlbertVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{AlbertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
@ -16,13 +16,13 @@ use tch::{nn, no_grad, Device, Tensor};
#[test]
fn albert_masked_lm() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let config_path = config_resource.get_local_path()?;
@ -87,10 +87,10 @@ fn albert_masked_lm() -> anyhow::Result<()> {
#[test]
fn albert_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = config_resource.get_local_path()?;
@ -153,10 +153,10 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
#[test]
fn albert_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = config_resource.get_local_path()?;
@ -219,10 +219,10 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
#[test]
fn albert_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = config_resource.get_local_path()?;
@ -286,10 +286,10 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
#[test]
fn albert_for_question_answering() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = config_resource.get_local_path()?;

View File

@ -6,7 +6,7 @@ use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationMode
use rust_bert::pipelines::zero_shot_classification::{
ZeroShotClassificationConfig, ZeroShotClassificationModel,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -14,16 +14,16 @@ use tch::{nn, Device, Tensor};
#[test]
fn bart_lm_model() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
BartModelResources::DISTILBART_CNN_6_6,
));
let config_path = config_resource.get_local_path()?;
@ -77,16 +77,16 @@ fn bart_lm_model() -> anyhow::Result<()> {
#[test]
fn bart_summarization_greedy() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
BartModelResources::DISTILBART_CNN_6_6,
));
let summarization_config = SummarizationConfig {
@ -138,16 +138,16 @@ about exoplanets like K2-18b."];
#[test]
fn bart_summarization_beam_search() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
BartModelResources::DISTILBART_CNN_6_6,
));
let summarization_config = SummarizationConfig {

View File

@ -11,7 +11,7 @@ use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
@ -21,12 +21,9 @@ use tch::{nn, no_grad, Device, Tensor};
#[test]
fn bert_masked_lm() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let weights_resource = RemoteResource::from_pretrained(BertModelResources::BERT);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
@ -106,10 +103,8 @@ fn bert_masked_lm() -> anyhow::Result<()> {
#[test]
fn bert_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
@ -170,10 +165,8 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
#[test]
fn bert_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
@ -230,10 +223,8 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
#[test]
fn bert_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
@ -295,10 +286,8 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
#[test]
fn bert_for_question_answering() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
@ -422,11 +411,9 @@ fn bert_question_answering() -> anyhow::Result<()> {
// Set-up question answering model
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
None, //merges resource only relevant with ModelType::Roberta
false,
false,

View File

@ -3,7 +3,7 @@ use rust_bert::deberta::{
DebertaForSequenceClassification, DebertaForTokenClassification, DebertaMergesResources,
DebertaModelResources, DebertaVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use std::collections::HashMap;
@ -14,16 +14,16 @@ extern crate anyhow;
#[test]
fn deberta_natural_language_inference() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
DebertaConfigResources::DEBERTA_BASE_MNLI,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
DebertaVocabResources::DEBERTA_BASE_MNLI,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
DebertaMergesResources::DEBERTA_BASE_MNLI,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
DebertaModelResources::DEBERTA_BASE_MNLI,
));
@ -87,7 +87,7 @@ fn deberta_natural_language_inference() -> anyhow::Result<()> {
#[test]
fn deberta_masked_lm() -> anyhow::Result<()> {
// Set-up masked LM model
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
DebertaConfigResources::DEBERTA_BASE_MNLI,
));
let config_path = config_resource.get_local_path()?;
@ -142,13 +142,13 @@ fn deberta_masked_lm() -> anyhow::Result<()> {
#[test]
fn deberta_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
DebertaConfigResources::DEBERTA_BASE_MNLI,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
DebertaVocabResources::DEBERTA_BASE_MNLI,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
DebertaMergesResources::DEBERTA_BASE_MNLI,
));
let config_path = config_resource.get_local_path()?;
@ -203,13 +203,13 @@ fn deberta_for_token_classification() -> anyhow::Result<()> {
#[test]
fn deberta_for_question_answering() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
DebertaConfigResources::DEBERTA_BASE_MNLI,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
DebertaVocabResources::DEBERTA_BASE_MNLI,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
DebertaMergesResources::DEBERTA_BASE_MNLI,
));
let config_path = config_resource.get_local_path()?;

View File

@ -5,7 +5,7 @@ use rust_bert::distilbert::{
};
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
@ -42,13 +42,13 @@ fn distilbert_sentiment_classifier() -> anyhow::Result<()> {
#[test]
fn distilbert_masked_lm() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_path = config_resource.get_local_path()?;
@ -123,10 +123,10 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
#[test]
fn distilbert_for_question_answering() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
));
let config_path = config_resource.get_local_path()?;
@ -188,10 +188,10 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
#[test]
fn distilbert_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let config_path = config_resource.get_local_path()?;

View File

@ -3,7 +3,7 @@ use rust_bert::gpt2::{
Gpt2VocabResources,
};
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -11,16 +11,16 @@ use tch::{nn, Device, Tensor};
#[test]
fn distilgpt2_lm_model() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
Gpt2ConfigResources::DISTIL_GPT2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
Gpt2VocabResources::DISTIL_GPT2,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::DISTIL_GPT2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::DISTIL_GPT2,
));
let config_path = config_resource.get_local_path()?;

View File

@ -2,7 +2,7 @@ use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM,
ElectraModelResources, ElectraVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
@ -11,13 +11,13 @@ use tch::{nn, no_grad, Device, Tensor};
#[test]
fn electra_masked_lm() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let config_path = config_resource.get_local_path()?;
@ -95,13 +95,13 @@ fn electra_masked_lm() -> anyhow::Result<()> {
#[test]
fn electra_discriminator() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let config_path = config_resource.get_local_path()?;

View File

@ -7,7 +7,7 @@ use rust_bert::fnet::{
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel, SentimentPolarity};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{FNetTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
@ -17,12 +17,9 @@ use tch::{nn, no_grad, Device, Tensor};
#[test]
fn fnet_masked_lm() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetModelResources::BASE));
let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let weights_resource = Box::new(RemoteResource::from_pretrained(FNetModelResources::BASE));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
@ -85,13 +82,13 @@ fn fnet_masked_lm() -> anyhow::Result<()> {
#[test]
fn fnet_for_sequence_classification() -> anyhow::Result<()> {
// Set up classifier
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
FNetConfigResources::BASE_SST2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
FNetVocabResources::BASE_SST2,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
FNetModelResources::BASE_SST2,
));
@ -128,10 +125,8 @@ fn fnet_for_sequence_classification() -> anyhow::Result<()> {
#[test]
fn fnet_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
@ -188,10 +183,8 @@ fn fnet_for_multiple_choice() -> anyhow::Result<()> {
#[test]
fn fnet_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
@ -251,10 +244,8 @@ fn fnet_for_token_classification() -> anyhow::Result<()> {
#[test]
fn fnet_for_question_answering() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;

View File

@ -10,7 +10,7 @@ use rust_bert::pipelines::generation_utils::{
Cache, GenerateConfig, GenerateOptions, LMHeadModel, LanguageGenerator,
};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -18,14 +18,10 @@ use tch::{nn, Device, Tensor};
#[test]
fn gpt2_lm_model() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2);
let vocab_resource = RemoteResource::from_pretrained(Gpt2VocabResources::GPT2);
let merges_resource = RemoteResource::from_pretrained(Gpt2MergesResources::GPT2);
let weights_resource = RemoteResource::from_pretrained(Gpt2ModelResources::GPT2);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
@ -114,14 +110,10 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
#[test]
fn gpt2_generation_greedy() -> anyhow::Result<()> {
// Resources definition
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
@ -150,14 +142,10 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
#[test]
fn gpt2_generation_beam_search() -> anyhow::Result<()> {
// Resources definition
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
@ -198,14 +186,10 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
#[test]
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Result<()> {
// Resources definition
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
@ -259,14 +243,10 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
#[test]
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
// Resources definition
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
@ -319,14 +299,10 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
#[test]
fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
// Resources definition
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
@ -381,14 +357,10 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()
#[test]
fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
// Resources definition
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
let paragraph_tokens = [198, 628];
@ -450,14 +422,10 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
#[test]
fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
// Resources definition
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig {
max_length: 36,
@ -520,14 +488,10 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
#[test]
fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
// Resources definition
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig {
max_length: 36,
@ -590,14 +554,10 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
#[test]
fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
// Resources definition
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
let paragraph_tokens = [198, 628];

View File

@ -4,7 +4,7 @@ use rust_bert::gpt_neo::{
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -12,16 +12,16 @@ use tch::{nn, Device, Tensor};
#[test]
fn gpt_neo_lm() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
GptNeoConfigResources::GPT_NEO_125M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
GptNeoVocabResources::GPT_NEO_125M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
GptNeoMergesResources::GPT_NEO_125M,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
GptNeoModelResources::GPT_NEO_125M,
));
let config_path = config_resource.get_local_path()?;
@ -109,16 +109,16 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
#[test]
fn test_generation_gpt_neo() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
GptNeoConfigResources::GPT_NEO_125M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
GptNeoVocabResources::GPT_NEO_125M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
GptNeoMergesResources::GPT_NEO_125M,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
GptNeoModelResources::GPT_NEO_125M,
));

View File

@ -11,7 +11,7 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{RobertaVocab, Vocab};
@ -21,18 +21,14 @@ use tch::{nn, no_grad, Device, Tensor};
#[test]
fn longformer_masked_lm() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_4096,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_4096,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerModelResources::LONGFORMER_BASE_4096,
));
let config_resource =
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
let vocab_resource =
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
let merges_resource =
RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
let weights_resource =
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_4096);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
@ -176,15 +172,12 @@ fn longformer_masked_lm() -> anyhow::Result<()> {
#[test]
fn longformer_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_4096,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_4096,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let config_resource =
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
let vocab_resource =
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
let merges_resource =
RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
@ -245,15 +238,12 @@ fn longformer_for_sequence_classification() -> anyhow::Result<()> {
#[test]
fn longformer_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_4096,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_4096,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let config_resource =
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
let vocab_resource =
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
let merges_resource =
RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
@ -321,15 +311,12 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
#[test]
fn mobilebert_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_4096,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_4096,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let config_resource =
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
let vocab_resource =
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
let merges_resource =
RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
@ -394,18 +381,12 @@ fn longformer_for_question_answering() -> anyhow::Result<()> {
// Set-up Question Answering model
let config = QuestionAnsweringConfig::new(
ModelType::Longformer,
Resource::Remote(RemoteResource::from_pretrained(
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
)),
Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
)),
Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
)),
Some(Resource::Remote(RemoteResource::from_pretrained(
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
Some(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
))),
)),
false,
None,
false,

View File

@ -4,7 +4,7 @@ use rust_bert::m2m_100::{
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{M2M100Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -12,18 +12,10 @@ use tch::{nn, Device, Tensor};
#[test]
fn m2m100_lm_model() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
));
let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
let weights_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
@ -76,18 +68,10 @@ fn m2m100_lm_model() -> anyhow::Result<()> {
#[test]
fn m2m100_translation() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
));
let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
let source_languages = M2M100SourceLanguages::M2M100_418M;
let target_languages = M2M100TargetLanguages::M2M100_418M;

View File

@ -6,25 +6,17 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{
Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use tch::Device;
#[test]
// #[cfg_attr(not(feature = "all-tests"), ignore)]
fn test_translation() -> anyhow::Result<()> {
// Set-up translation model
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianModelResources::ENGLISH2ROMANCE,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianConfigResources::ENGLISH2ROMANCE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianVocabResources::ENGLISH2ROMANCE,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianSpmResources::ENGLISH2ROMANCE,
));
let model_resource = RemoteResource::from_pretrained(MarianModelResources::ENGLISH2ROMANCE);
let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ENGLISH2ROMANCE);
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ENGLISH2ROMANCE);
let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ENGLISH2ROMANCE);
let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;

View File

@ -3,7 +3,7 @@ use rust_bert::mbart::{
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -11,13 +11,13 @@ use tch::{nn, Device, Tensor};
#[test]
fn mbart_lm_model() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
));
let config_path = config_resource.get_local_path()?;

View File

@ -5,7 +5,7 @@ use rust_bert::mobilebert::{
MobileBertModelResources, MobileBertVocabResources,
};
use rust_bert::pipelines::pos_tagging::POSModel;
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
@ -15,13 +15,13 @@ use tch::{nn, no_grad, Device, Tensor};
#[test]
fn mobilebert_masked_model() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
MobileBertModelResources::MOBILEBERT_UNCASED,
));
let config_path = config_resource.get_local_path()?;
@ -111,10 +111,10 @@ fn mobilebert_masked_model() -> anyhow::Result<()> {
#[test]
fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED,
));
let config_path = config_resource.get_local_path()?;
@ -162,10 +162,10 @@ fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
#[test]
fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED,
));
let config_path = config_resource.get_local_path()?;
@ -220,10 +220,10 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
#[test]
fn mobilebert_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED,
));
let config_path = config_resource.get_local_path()?;
@ -273,10 +273,10 @@ fn mobilebert_for_token_classification() -> anyhow::Result<()> {
#[test]
fn mobilebert_for_question_answering() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED,
));
let config_path = config_resource.get_local_path()?;

View File

@ -6,7 +6,7 @@ use rust_bert::openai_gpt::{
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -14,16 +14,16 @@ use tch::{nn, Device, Tensor};
#[test]
fn openai_gpt_lm_model() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let config_path = config_resource.get_local_path()?;
@ -104,16 +104,16 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
#[test]
fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
@ -146,16 +146,16 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
#[test]
fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
@ -199,16 +199,16 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
#[test]
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
@ -268,16 +268,16 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
#[test]
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));

View File

@ -2,19 +2,19 @@ use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationMode
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use tch::Device;
#[test]
fn pegasus_summarization_greedy() -> anyhow::Result<()> {
// Set-up model
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
PegasusConfigResources::CNN_DAILYMAIL,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
PegasusVocabResources::CNN_DAILYMAIL,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
PegasusModelResources::CNN_DAILYMAIL,
));

View File

@ -4,19 +4,19 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::prophetnet::{
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use tch::Device;
#[test]
fn prophetnet_summarization_greedy() -> anyhow::Result<()> {
// Set-up model
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
));

View File

@ -4,7 +4,7 @@ use rust_bert::reformer::{
ReformerConfig, ReformerConfigResources, ReformerForQuestionAnswering,
ReformerForSequenceClassification, ReformerModelResources, ReformerVocabResources,
};
use rust_bert::resources::{LocalResource, RemoteResource, Resource};
use rust_bert::resources::{LocalResource, RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, ReformerTokenizer, TruncationStrategy};
use std::collections::HashMap;
@ -17,7 +17,7 @@ use tch::{nn, no_grad, Device, Tensor};
fn test_generation_reformer() -> anyhow::Result<()> {
// ===================================================
// Modify resource to enforce seed
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT,
));
@ -31,18 +31,18 @@ fn test_generation_reformer() -> anyhow::Result<()> {
let _ = updated_config_file.write_all(serde_json::to_string(&config).unwrap().as_bytes());
let updated_config_path = updated_config_file.into_temp_path();
let config_resource = Resource::Local(LocalResource {
let config_resource = Box::new(LocalResource {
local_path: updated_config_path.to_path_buf(),
});
// ===================================================
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT,
));
// Set-up translation model
@ -79,10 +79,10 @@ fn test_generation_reformer() -> anyhow::Result<()> {
#[test]
fn reformer_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let config_path = config_resource.get_local_path()?;
@ -145,10 +145,10 @@ fn reformer_for_sequence_classification() -> anyhow::Result<()> {
#[test]
fn reformer_for_question_answering() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let config_path = config_resource.get_local_path()?;

View File

@ -5,7 +5,7 @@ use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::pipelines::token_classification::TokenClassificationConfig;
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
RobertaForSequenceClassification, RobertaForTokenClassification, RobertaMergesResources,
@ -20,18 +20,13 @@ use tch::{nn, no_grad, Device, Tensor};
#[test]
fn roberta_masked_lm() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::DISTILROBERTA_BASE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::DISTILROBERTA_BASE,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::DISTILROBERTA_BASE,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::DISTILROBERTA_BASE,
));
let config_resource =
RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
let merges_resource =
RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
let weights_resource =
RemoteResource::from_pretrained(RobertaModelResources::DISTILROBERTA_BASE);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
@ -116,15 +111,11 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
#[test]
fn roberta_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::DISTILROBERTA_BASE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::DISTILROBERTA_BASE,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::DISTILROBERTA_BASE,
));
let config_resource =
RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
let merges_resource =
RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
@ -190,15 +181,11 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
#[test]
fn roberta_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::DISTILROBERTA_BASE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::DISTILROBERTA_BASE,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::DISTILROBERTA_BASE,
));
let config_resource =
RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
let merges_resource =
RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
@ -260,15 +247,11 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
#[test]
fn roberta_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::DISTILROBERTA_BASE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::DISTILROBERTA_BASE,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::DISTILROBERTA_BASE,
));
let config_resource =
RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
let merges_resource =
RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
@ -337,18 +320,12 @@ fn roberta_question_answering() -> anyhow::Result<()> {
// Set-up question answering model
let config = QuestionAnsweringConfig::new(
ModelType::Roberta,
Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA_QA,
)),
Some(Resource::Remote(RemoteResource::from_pretrained(
RemoteResource::from_pretrained(RobertaModelResources::ROBERTA_QA),
RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA_QA),
RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA_QA),
Some(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA_QA,
))),
)),
false,
None,
false,
@ -378,13 +355,13 @@ fn xlm_roberta_german_ner() -> anyhow::Result<()> {
// Set-up question answering model
let ner_config = TokenClassificationConfig {
model_type: ModelType::XLMRoberta,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
model_resource: Box::new(RemoteResource::from_pretrained(
RobertaModelResources::XLM_ROBERTA_NER_DE,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
config_resource: Box::new(RemoteResource::from_pretrained(
RobertaConfigResources::XLM_ROBERTA_NER_DE,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
vocab_resource: Box::new(RemoteResource::from_pretrained(
RobertaVocabResources::XLM_ROBERTA_NER_DE,
)),
lower_case: false,

View File

@ -1,20 +1,16 @@
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
use tch::Device;
#[test]
fn test_translation_t5() -> anyhow::Result<()> {
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_SMALL);
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL);
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
let source_languages = [
Language::English,
@ -70,18 +66,10 @@ fn test_summarization_t5() -> anyhow::Result<()> {
// Set-up translation model
let summarization_config = SummarizationConfig {
model_type: ModelType::T5,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
T5ModelResources::T5_SMALL,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
T5ConfigResources::T5_SMALL,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
T5VocabResources::T5_SMALL,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
T5VocabResources::T5_SMALL,
)),
model_resource: Box::new(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)),
config_resource: Box::new(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)),
vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
merges_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
min_length: 30,
max_length: 200,
early_stopping: true,

View File

@ -1,6 +1,6 @@
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::xlnet::{
XLNetConfig, XLNetConfigResources, XLNetForMultipleChoice, XLNetForQuestionAnswering,
XLNetForSequenceClassification, XLNetForTokenClassification, XLNetLMHeadModel, XLNetModel,
@ -15,13 +15,13 @@ use tch::{nn, no_grad, Device, Kind, Tensor};
#[test]
fn xlnet_base_model() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED,
));
let config_path = config_resource.get_local_path()?;
@ -122,13 +122,13 @@ fn xlnet_base_model() -> anyhow::Result<()> {
#[test]
fn xlnet_lm_model() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
let weights_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED,
));
let config_path = config_resource.get_local_path()?;
@ -196,16 +196,16 @@ fn xlnet_lm_model() -> anyhow::Result<()> {
#[test]
fn xlnet_generation_beam_search() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
let merges_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
let model_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED,
));
@ -239,10 +239,10 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
#[test]
fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let config_path = config_resource.get_local_path()?;
@ -311,10 +311,10 @@ fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
#[test]
fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let config_path = config_resource.get_local_path()?;
@ -379,10 +379,10 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
#[test]
fn xlnet_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let config_path = config_resource.get_local_path()?;
@ -442,10 +442,10 @@ fn xlnet_for_token_classification() -> anyhow::Result<()> {
#[test]
fn xlnet_for_question_answering() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let config_path = config_resource.get_local_path()?;