Refactor: Feature gate remote resource (#223)

* Extract `get_local_path` into a `LocalPathProvider` trait

* Remove config default impls

* Feature gate RemoteResource

* translation_builder refactoring to have remote fetching grouped

* Include dirs crate in remote feature gate

* Examples fixes

* Benches fixes

* Tests fix

* Remove Box from constructor parameters

* Fix examples no-Box

* Fix benches no-Box

* Fix tests no-Box

* Fix doc comment code

* Fix documentation `Resource` -> `ResourceProvider`

* Moved remote and local resources to the same module level

* Moved `ResourceProvider` to the `resources` module

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
This commit is contained in:
Jonas Hedman Engström 2022-02-25 22:24:03 +01:00 committed by GitHub
parent 23c5d9112a
commit 9b22c2482a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
91 changed files with 1343 additions and 1807 deletions

View File

@ -50,8 +50,10 @@ harness = false
opt-level = 3 opt-level = 3
[features] [features]
default = ["remote"]
doc-only = ["tch/doc-only"] doc-only = ["tch/doc-only"]
all-tests = [] all-tests = []
remote = [ "cached-path", "dirs", "lazy_static" ]
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["doc-only"] features = ["doc-only"]
@ -61,14 +63,15 @@ rust_tokenizers = "~7.0.1"
tch = "~0.6.1" tch = "~0.6.1"
serde_json = "1.0.73" serde_json = "1.0.73"
serde = { version = "1.0.132", features = ["derive"] } serde = { version = "1.0.132", features = ["derive"] }
dirs = "4.0.0"
ordered-float = "2.8.0" ordered-float = "2.8.0"
cached-path = "0.5.1"
lazy_static = "1.4.0"
uuid = { version = "0.8.2", features = ["v4"] } uuid = { version = "0.8.2", features = ["v4"] }
thiserror = "1.0.30" thiserror = "1.0.30"
half = "1.8.2" half = "1.8.2"
cached-path = { version = "0.5.1", optional = true }
dirs = { version = "4.0.0", optional = true }
lazy_static = { version = "1.4.0", optional = true }
[dev-dependencies] [dev-dependencies]
anyhow = "1.0.51" anyhow = "1.0.51"
csv = "1.1.6" csv = "1.1.6"

View File

@ -7,21 +7,17 @@ use rust_bert::gpt2::{
}; };
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tch::Device; use tch::Device;
fn create_text_generation_model() -> TextGenerationModel { fn create_text_generation_model() -> TextGenerationModel {
let config = TextGenerationConfig { let config = TextGenerationConfig {
model_type: ModelType::GPT2, model_type: ModelType::GPT2,
model_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)), model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Resource::Remote(RemoteResource::from_pretrained( config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
Gpt2ConfigResources::GPT2, vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
)), merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
)),
min_length: 0, min_length: 0,
max_length: 30, max_length: 30,
do_sample: true, do_sample: true,

View File

@ -7,7 +7,7 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{ use rust_bert::pipelines::question_answering::{
squad_processor, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel, squad_processor, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use std::env; use std::env;
use std::path::PathBuf; use std::path::PathBuf;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@ -17,11 +17,9 @@ static BATCH_SIZE: usize = 64;
fn create_qa_model() -> QuestionAnsweringModel { fn create_qa_model() -> QuestionAnsweringModel {
let config = QuestionAnsweringConfig::new( let config = QuestionAnsweringConfig::new(
ModelType::Bert, ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)), RemoteResource::from_pretrained(BertModelResources::BERT_QA),
Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
BertConfigResources::BERT_QA, RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
None, //merges resource only relevant with ModelType::Roberta None, //merges resource only relevant with ModelType::Roberta
false, //lowercase false, //lowercase
false, false,
@ -54,11 +52,9 @@ fn qa_load_model(iters: u64) -> Duration {
let start = Instant::now(); let start = Instant::now();
let config = QuestionAnsweringConfig::new( let config = QuestionAnsweringConfig::new(
ModelType::Bert, ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)), RemoteResource::from_pretrained(BertModelResources::BERT_QA),
Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
BertConfigResources::BERT_QA, RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
None, //merges resource only relevant with ModelType::Roberta None, //merges resource only relevant with ModelType::Roberta
false, //lowercase false, //lowercase
false, false,

View File

@ -19,21 +19,21 @@ use rust_bert::gpt_neo::{
}; };
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use tch::Device; use tch::Device;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Set-up model resources // Set-up model resources
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
GptNeoConfigResources::GPT_NEO_125M, GptNeoConfigResources::GPT_NEO_125M,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
GptNeoVocabResources::GPT_NEO_125M, GptNeoVocabResources::GPT_NEO_125M,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
GptNeoMergesResources::GPT_NEO_125M, GptNeoMergesResources::GPT_NEO_125M,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
GptNeoModelResources::GPT_NEO_125M, GptNeoModelResources::GPT_NEO_125M,
)); ));
let generate_config = TextGenerationConfig { let generate_config = TextGenerationConfig {

View File

@ -19,21 +19,21 @@ use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGeneration
use rust_bert::reformer::{ use rust_bert::reformer::{
ReformerConfigResources, ReformerModelResources, ReformerVocabResources, ReformerConfigResources, ReformerModelResources, ReformerVocabResources,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Set-up model // Set-up model
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT, ReformerConfigResources::CRIME_AND_PUNISHMENT,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT, ReformerVocabResources::CRIME_AND_PUNISHMENT,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT, ReformerVocabResources::CRIME_AND_PUNISHMENT,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT, ReformerModelResources::CRIME_AND_PUNISHMENT,
)); ));
let generate_config = TextGenerationConfig { let generate_config = TextGenerationConfig {

View File

@ -16,21 +16,21 @@ extern crate anyhow;
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources}; use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED, XLNetConfigResources::XLNET_BASE_CASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED, XLNetModelResources::XLNET_BASE_CASED,
)); ));

View File

@ -15,7 +15,7 @@ extern crate anyhow;
use rust_bert::bert::{ use rust_bert::bert::{
BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources, BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab; use rust_tokenizers::vocab::Vocab;
@ -23,12 +23,9 @@ use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT)); let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let vocab_resource = let weights_resource = RemoteResource::from_pretrained(BertModelResources::BERT);
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?; let weights_path = weights_resource.get_local_path()?;

View File

@ -4,23 +4,23 @@ use rust_bert::deberta::{
DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification, DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
DebertaMergesResources, DebertaModelResources, DebertaVocabResources, DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Kind, Tensor}; use tch::{nn, no_grad, Device, Kind, Tensor};
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
DebertaConfigResources::DEBERTA_BASE_MNLI, DebertaConfigResources::DEBERTA_BASE_MNLI,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
DebertaVocabResources::DEBERTA_BASE_MNLI, DebertaVocabResources::DEBERTA_BASE_MNLI,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
DebertaMergesResources::DEBERTA_BASE_MNLI, DebertaMergesResources::DEBERTA_BASE_MNLI,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
DebertaModelResources::DEBERTA_BASE_MNLI, DebertaModelResources::DEBERTA_BASE_MNLI,
)); ));

View File

@ -17,17 +17,15 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{ use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Set-up Question Answering model // Set-up Question Answering model
let config = QuestionAnsweringConfig::new( let config = QuestionAnsweringConfig::new(
ModelType::Bert, ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)), RemoteResource::from_pretrained(BertModelResources::BERT_QA),
Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
BertConfigResources::BERT_QA, RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
None, //merges resource only relevant with ModelType::Roberta None, //merges resource only relevant with ModelType::Roberta
false, false,
false, false,

View File

@ -20,24 +20,18 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{ use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Set-up Question Answering model // Set-up Question Answering model
let config = QuestionAnsweringConfig::new( let config = QuestionAnsweringConfig::new(
ModelType::Longformer, ModelType::Longformer,
Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
LongformerModelResources::LONGFORMER_BASE_SQUAD1, RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
)), RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
Resource::Remote(RemoteResource::from_pretrained( Some(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
)),
Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
)),
Some(Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_SQUAD1, LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
))), )),
false, false,
None, None,
false, false,

View File

@ -15,17 +15,17 @@ extern crate anyhow;
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources}; use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel}; use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Set-up classifier // Set-up classifier
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
FNetConfigResources::BASE_SST2, FNetConfigResources::BASE_SST2,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
FNetVocabResources::BASE_SST2, FNetVocabResources::BASE_SST2,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
FNetModelResources::BASE_SST2, FNetModelResources::BASE_SST2,
)); ));

View File

@ -16,20 +16,20 @@ use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources, BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
}; };
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use tch::Device; use tch::Device;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6, BartConfigResources::DISTILBART_CNN_6_6,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6, BartVocabResources::DISTILBART_CNN_6_6,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6, BartMergesResources::DISTILBART_CNN_6_6,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
BartModelResources::DISTILBART_CNN_6_6, BartModelResources::DISTILBART_CNN_6_6,
)); ));

View File

@ -15,17 +15,17 @@ extern crate anyhow;
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources}; use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use tch::Device; use tch::Device;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
PegasusConfigResources::CNN_DAILYMAIL, PegasusConfigResources::CNN_DAILYMAIL,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
PegasusVocabResources::CNN_DAILYMAIL, PegasusVocabResources::CNN_DAILYMAIL,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
PegasusModelResources::CNN_DAILYMAIL, PegasusModelResources::CNN_DAILYMAIL,
)); ));

View File

@ -17,17 +17,17 @@ use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationMode
use rust_bert::prophetnet::{ use rust_bert::prophetnet::{
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources, ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use tch::Device; use tch::Device;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM, ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM, ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM, ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
)); ));

View File

@ -14,16 +14,14 @@ extern crate anyhow;
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources}; use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let config_resource = let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL);
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)); let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
let vocab_resource = let weights_resource = RemoteResource::from_pretrained(T5ModelResources::T5_SMALL);
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
let summarization_config = SummarizationConfig::new( let summarization_config = SummarizationConfig::new(
ModelType::T5, ModelType::T5,
weights_resource, weights_resource,

View File

@ -16,21 +16,15 @@ use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::token_classification::{ use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig, LabelAggregationOption, TokenClassificationConfig,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Load a configuration // Load a configuration
let config = TokenClassificationConfig::new( let config = TokenClassificationConfig::new(
ModelType::Bert, ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(BertModelResources::BERT_NER),
BertModelResources::BERT_NER, RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
)), RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
)),
Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
)),
None, //merges resource only relevant with ModelType::Roberta None, //merges resource only relevant with ModelType::Roberta
false, //lowercase false, //lowercase
false, false,

View File

@ -18,22 +18,14 @@ use rust_bert::m2m_100::{
}; };
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use tch::Device; use tch::Device;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
M2M100ModelResources::M2M100_418M, let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
)); let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
M2M100ConfigResources::M2M100_418M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
));
let source_languages = M2M100SourceLanguages::M2M100_418M; let source_languages = M2M100SourceLanguages::M2M100_418M;
let target_languages = M2M100TargetLanguages::M2M100_418M; let target_languages = M2M100TargetLanguages::M2M100_418M;

View File

@ -19,22 +19,14 @@ use rust_bert::marian::{
}; };
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel}; use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use tch::Device; use tch::Device;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = RemoteResource::from_pretrained(MarianModelResources::ENGLISH2CHINESE);
MarianModelResources::ENGLISH2CHINESE, let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ENGLISH2CHINESE);
)); let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ENGLISH2CHINESE);
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ENGLISH2CHINESE);
MarianConfigResources::ENGLISH2CHINESE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianVocabResources::ENGLISH2CHINESE,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianSpmResources::ENGLISH2CHINESE,
));
let source_languages = MarianSourceLanguages::ENGLISH2CHINESE; let source_languages = MarianSourceLanguages::ENGLISH2CHINESE;
let target_languages = MarianTargetLanguages::ENGLISH2CHINESE; let target_languages = MarianTargetLanguages::ENGLISH2CHINESE;

View File

@ -18,22 +18,16 @@ use rust_bert::mbart::{
}; };
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use tch::Device; use tch::Device;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = RemoteResource::from_pretrained(MBartModelResources::MBART50_MANY_TO_MANY);
MBartModelResources::MBART50_MANY_TO_MANY, let config_resource =
)); RemoteResource::from_pretrained(MBartConfigResources::MBART50_MANY_TO_MANY);
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
MBartConfigResources::MBART50_MANY_TO_MANY, let merges_resource =
)); RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
));
let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY; let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY;
let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY; let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY;

View File

@ -14,19 +14,15 @@ extern crate anyhow;
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources}; use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
use tch::Device; use tch::Device;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let model_resource = let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_BASE);
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE)); let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_BASE);
let config_resource = let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_BASE)); let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
let source_languages = [ let source_languages = [
Language::English, Language::English,

View File

@ -24,19 +24,19 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM}; //! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::AlbertTokenizer; //! use rust_tokenizers::tokenizer::AlbertTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?; //! let weights_path = weights_resource.get_local_path()?;

View File

@ -17,10 +17,6 @@ use crate::bart::encoder::BartEncoder;
use crate::common::activations::Activation; use crate::common::activations::Activation;
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use crate::common::kind::get_negative_infinity; use crate::common::kind::get_negative_infinity;
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{ use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator, PreparedInput, PrivateLanguageGenerator,
@ -1028,43 +1024,10 @@ impl BartGenerator {
/// # } /// # }
/// ``` /// ```
pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models let config_path = generate_config.config_resource.get_local_path()?;
let model_resource = if generate_config.model_resource let vocab_path = generate_config.vocab_resource.get_local_path()?;
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) let merges_path = generate_config.merges_resource.get_local_path()?;
{ let weights_path = generate_config.model_resource.get_local_path()?;
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART))
} else {
generate_config.merges_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device; let device = generate_config.device;
generate_config.validate(); generate_config.validate();
@ -1293,7 +1256,7 @@ mod test {
use tch::Device; use tch::Device;
use crate::{ use crate::{
resources::{RemoteResource, Resource}, resources::{RemoteResource, ResourceProvider},
Config, Config,
}; };
@ -1302,8 +1265,7 @@ mod test {
#[test] #[test]
#[ignore] // compilation is enough, no need to run #[ignore] // compilation is enough, no need to run
fn bart_model_send() { fn bart_model_send() {
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(BartConfigResources::BART));
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let config_path = config_resource.get_local_path().expect(""); let config_path = config_resource.get_local_path().expect("");
// Set-up masked LM model // Set-up masked LM model

View File

@ -19,22 +19,22 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::bart::{BartConfig, BartModel}; //! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::RobertaTokenizer; //! use rust_tokenizers::tokenizer::RobertaTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let merges_resource = Resource::Local(LocalResource { //! let merges_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?; //! let merges_path = merges_resource.get_local_path()?;

View File

@ -1215,7 +1215,7 @@ mod test {
use tch::Device; use tch::Device;
use crate::{ use crate::{
resources::{RemoteResource, Resource}, resources::{RemoteResource, ResourceProvider},
Config, Config,
}; };
@ -1224,8 +1224,7 @@ mod test {
#[test] #[test]
#[ignore] // compilation is enough, no need to run #[ignore] // compilation is enough, no need to run
fn bert_model_send() { fn bert_model_send() {
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(BertConfigResources::BERT));
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let config_path = config_resource.get_local_path().expect(""); let config_path = config_resource.get_local_path().expect("");
// Set-up masked LM model // Set-up masked LM model

View File

@ -24,19 +24,19 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::bert::{BertConfig, BertForMaskedLM}; //! use rust_bert::bert::{BertConfig, BertForMaskedLM};
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer; //! use rust_tokenizers::tokenizer::BertTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?; //! let weights_path = weights_resource.get_local_path()?;

View File

@ -4,8 +4,9 @@ use thiserror::Error;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum RustBertError { pub enum RustBertError {
#[cfg(feature = "remote")]
#[error("Endpoint not available error: {0}")] #[error("Endpoint not available error: {0}")]
FileDownloadError(String), FileDownloadError(#[from] cached_path::Error),
#[error("IO error: {0}")] #[error("IO error: {0}")]
IOError(String), IOError(String),
@ -23,12 +24,6 @@ pub enum RustBertError {
ValueError(String), ValueError(String),
} }
impl From<cached_path::Error> for RustBertError {
fn from(error: cached_path::Error) -> Self {
RustBertError::FileDownloadError(error.to_string())
}
}
impl From<std::io::Error> for RustBertError { impl From<std::io::Error> for RustBertError {
fn from(error: std::io::Error) -> Self { fn from(error: std::io::Error) -> Self {
RustBertError::IOError(error.to_string()) RustBertError::IOError(error.to_string())

View File

@ -1,197 +0,0 @@
//! # Resource definitions for model weights, vocabularies and configuration files
//!
//! This crate relies on the concept of Resources to access the files used by the models.
//! This includes:
//! - model weights
//! - configuration files
//! - vocabularies
//! - (optional) merges files for BPE-based tokenizers
//!
//! These are expected in the pipelines configurations or are used as utilities to reference to the
//! resource location. Two types of resources exist:
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL and a local cached file
//!
//! For both types of resources, the local location of teh file can be retrieved using
//! `get_local_path`, allowing to reference the resource file location regardless if it is a remote
//! or local resource. Default implementations for a number of `RemoteResources` are available as
//! pre-trained models in each model module.
use crate::common::error::RustBertError;
use cached_path::{Cache, Options, ProgressBar};
use lazy_static::lazy_static;
use std::env;
use std::path::PathBuf;
extern crate dirs;
/// # Resource Enum pointing to model, configuration or vocabulary resources
/// Can be of type:
/// - LocalResource
/// - RemoteResource
#[derive(PartialEq, Clone)]
pub enum Resource {
Local(LocalResource),
Remote(RemoteResource),
}
impl Resource {
/// Gets the local path for a given resource.
///
/// If the resource is a remote resource, it is downloaded and cached. Then the path
/// to the local cache is returned.
///
/// # Returns
///
/// * `PathBuf` pointing to the resource file
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{LocalResource, Resource};
/// use std::path::PathBuf;
/// let config_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// });
/// let config_path = config_resource.get_local_path();
/// ```
pub fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
match self {
Resource::Local(resource) => Ok(resource.local_path.clone()),
Resource::Remote(resource) => {
let cached_path = CACHE.cached_path_with_options(
&resource.url,
&Options::default().subdir(&resource.cache_subdir),
)?;
Ok(cached_path)
}
}
}
}
/// # Local resource
#[derive(PartialEq, Clone)]
pub struct LocalResource {
/// Local path for the resource
pub local_path: PathBuf,
}
/// # Remote resource
#[derive(PartialEq, Clone)]
pub struct RemoteResource {
/// Remote path/url for the resource
pub url: String,
/// Local subdirectory of the cache root where this resource is saved
pub cache_subdir: String,
}
impl RemoteResource {
/// Creates a new RemoteResource from an URL and a custom local path. Note that this does not
/// download the resource (only declares the remote and local locations)
///
/// # Arguments
///
/// * `url` - `&str` Location of the remote resource
/// * `cache_subdir` - `&str` Local subdirectory of the cache root to save the resource to
///
/// # Returns
///
/// * `RemoteResource` RemoteResource object
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{RemoteResource, Resource};
/// let config_resource = Resource::Remote(RemoteResource::new(
/// "configs",
/// "http://config_json_location",
/// ));
/// ```
pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
RemoteResource {
url: url.to_string(),
cache_subdir: cache_subdir.to_string(),
}
}
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
/// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares
/// the remote and local locations)
///
/// # Arguments
///
/// * `name_url_tuple` - `(&str, &str)` Location of the name of model and remote resource
///
/// # Returns
///
/// * `RemoteResource` RemoteResource object
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2",
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
/// )));
/// ```
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
let cache_subdir = name_url_tuple.0.to_string();
let url = name_url_tuple.1.to_string();
RemoteResource { url, cache_subdir }
}
}
lazy_static! {
#[derive(Copy, Clone, Debug)]
/// # Global cache directory
/// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
/// location. Otherwise defaults to `~/.cache/.rustbert`.
pub static ref CACHE: Cache = Cache::builder()
.dir(_get_cache_directory())
.progress_bar(Some(ProgressBar::Light))
.build().unwrap();
}
fn _get_cache_directory() -> PathBuf {
match env::var("RUSTBERT_CACHE") {
Ok(value) => PathBuf::from(value),
Err(_) => {
let mut home = dirs::home_dir().unwrap();
home.push(".cache");
home.push(".rustbert");
home
}
}
}
#[deprecated(
since = "0.9.1",
note = "Please use `Resource.get_local_path()` instead"
)]
/// # (Download) the resource and return a path to its local path
/// This function will download remote resource to their local path if they do not exist yet.
/// Then for both `LocalResource` and `RemoteResource`, it will the local path to the resource.
/// For `LocalResource` only the resource path is returned.
///
/// # Arguments
///
/// * `resource` - Pointer to the `&Resource` to optionally download and get the local path.
///
/// # Returns
///
/// * `&PathBuf` Local path for the resource
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2/model.ot",
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
/// )));
/// let local_path = model_resource.get_local_path();
/// ```
pub fn download_resource(resource: &Resource) -> Result<PathBuf, RustBertError> {
resource.get_local_path()
}

View File

@ -0,0 +1,32 @@
use crate::common::error::RustBertError;
use crate::resources::ResourceProvider;
use std::path::PathBuf;
/// # Local resource
#[derive(PartialEq, Clone)]
pub struct LocalResource {
/// Local path for the resource
pub local_path: PathBuf,
}
impl ResourceProvider for LocalResource {
/// Gets the path for a local resource.
///
/// # Returns
///
/// * `PathBuf` pointing to the resource file
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_path = config_resource.get_local_path();
/// ```
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
Ok(self.local_path.clone())
}
}

View File

@ -0,0 +1,50 @@
//! # Resource definitions for model weights, vocabularies and configuration files
//!
//! This crate relies on the concept of Resources to access the files used by the models.
//! This includes:
//! - model weights
//! - configuration files
//! - vocabularies
//! - (optional) merges files for BPE-based tokenizers
//!
//! These are expected in the pipelines configurations or are used as utilities to reference to the
//! resource location. Two types of resources are pre-defined:
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL
//!
//! For both types of resources, the local location of the file can be retrieved using
//! `get_local_path`, allowing to reference the resource file location regardless if it is a remote
//! or local resource. Default implementations for a number of `RemoteResources` are available as
//! pre-trained models in each model module.
mod local;
use crate::common::error::RustBertError;
pub use local::LocalResource;
use std::path::PathBuf;
/// # Resource Trait that can provide the location of the model, configuration or vocabulary resources
pub trait ResourceProvider {
/// Provides the local path for a resource.
///
/// # Returns
///
/// * `PathBuf` pointing to the resource file
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_path = config_resource.get_local_path();
/// ```
fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
}
#[cfg(feature = "remote")]
mod remote;
#[cfg(feature = "remote")]
pub use remote::RemoteResource;

View File

@ -0,0 +1,122 @@
use super::*;
use crate::common::error::RustBertError;
use cached_path::{Cache, Options, ProgressBar};
use dirs::cache_dir;
use lazy_static::lazy_static;
use std::path::PathBuf;
/// # Remote resource that will be downloaded and cached locally on demand
#[derive(PartialEq, Clone)]
pub struct RemoteResource {
/// Remote path/url for the resource
pub url: String,
/// Local subdirectory of the cache root where this resource is saved
pub cache_subdir: String,
}
impl RemoteResource {
/// Creates a new RemoteResource from an URL and a custom local path. Note that this does not
/// download the resource (only declares the remote and local locations)
///
/// # Arguments
///
/// * `url` - `&str` Location of the remote resource
/// * `cache_subdir` - `&str` Local subdirectory of the cache root to save the resource to
///
/// # Returns
///
/// * `RemoteResource` RemoteResource object
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::RemoteResource;
/// let config_resource = RemoteResource::new(
/// "configs",
/// "http://config_json_location",
/// );
/// ```
pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
RemoteResource {
url: url.to_string(),
cache_subdir: cache_subdir.to_string(),
}
}
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
/// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares
/// the remote and local locations)
///
/// # Arguments
///
/// * `name_url_tuple` - `(&str, &str)` Location of the name of model and remote resource
///
/// # Returns
///
/// * `RemoteResource` RemoteResource object
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::RemoteResource;
/// let model_resource = RemoteResource::from_pretrained((
/// "distilbert-sst2",
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
/// ));
/// ```
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
let cache_subdir = name_url_tuple.0.to_string();
let url = name_url_tuple.1.to_string();
RemoteResource { url, cache_subdir }
}
}
impl ResourceProvider for RemoteResource {
/// Gets the local path for a remote resource.
///
/// The remote resource is downloaded and cached. Then the path
/// to the local cache is returned.
///
/// # Returns
///
/// * `PathBuf` pointing to the resource file
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_path = config_resource.get_local_path();
/// ```
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
let cached_path = CACHE
.cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
Ok(cached_path)
}
}
lazy_static! {
#[derive(Copy, Clone, Debug)]
/// # Global cache directory
/// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
/// location. Otherwise defaults to `$XDG_CACHE_HOME/.rustbert`, or corresponding user cache for
/// the current system.
pub static ref CACHE: Cache = Cache::builder()
.dir(_get_cache_directory())
.progress_bar(Some(ProgressBar::Light))
.build().unwrap();
}
fn _get_cache_directory() -> PathBuf {
match std::env::var("RUSTBERT_CACHE") {
Ok(value) => PathBuf::from(value),
Err(_) => {
let mut home = cache_dir().unwrap();
home.push(".rustbert");
home
}
}
}

View File

@ -23,22 +23,22 @@
//! DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification, //! DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
//! DebertaMergesResources, DebertaModelResources, DebertaVocabResources, //! DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
//! }; //! };
//! use rust_bert::resources::{RemoteResource, Resource}; //! use rust_bert::resources::{RemoteResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::DeBERTaTokenizer; //! use rust_tokenizers::tokenizer::DeBERTaTokenizer;
//! //!
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained( //! let config_resource = RemoteResource::from_pretrained(
//! DebertaConfigResources::DEBERTA_BASE_MNLI, //! DebertaConfigResources::DEBERTA_BASE_MNLI,
//! )); //! );
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( //! let vocab_resource = RemoteResource::from_pretrained(
//! DebertaVocabResources::DEBERTA_BASE_MNLI, //! DebertaVocabResources::DEBERTA_BASE_MNLI,
//! )); //! );
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained( //! let merges_resource = RemoteResource::from_pretrained(
//! DebertaMergesResources::DEBERTA_BASE_MNLI, //! DebertaMergesResources::DEBERTA_BASE_MNLI,
//! )); //! );
//! let weights_resource = Resource::Remote(RemoteResource::from_pretrained( //! let weights_resource = RemoteResource::from_pretrained(
//! DebertaModelResources::DEBERTA_BASE_MNLI, //! DebertaModelResources::DEBERTA_BASE_MNLI,
//! )); //! );
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?; //! let merges_path = merges_resource.get_local_path()?;

View File

@ -25,19 +25,19 @@
//! DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, //! DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM,
//! DistilBertModelResources, DistilBertVocabResources, //! DistilBertModelResources, DistilBertVocabResources,
//! }; //! };
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer; //! use rust_tokenizers::tokenizer::BertTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?; //! let weights_path = weights_resource.get_local_path()?;

View File

@ -27,19 +27,19 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM}; //! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer; //! use rust_tokenizers::tokenizer::BertTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?; //! let weights_path = weights_resource.get_local_path()?;

View File

@ -1029,7 +1029,7 @@ mod test {
use tch::Device; use tch::Device;
use crate::{ use crate::{
resources::{RemoteResource, Resource}, resources::{RemoteResource, ResourceProvider},
Config, Config,
}; };
@ -1038,8 +1038,7 @@ mod test {
#[test] #[test]
#[ignore] // compilation is enough, no need to run #[ignore] // compilation is enough, no need to run
fn fnet_model_send() { fn fnet_model_send() {
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
let config_path = config_resource.get_local_path().expect(""); let config_path = config_resource.get_local_path().expect("");
// Set-up masked LM model // Set-up masked LM model

View File

@ -22,19 +22,19 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM}; //! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::{BertTokenizer, FNetTokenizer}; //! use rust_tokenizers::tokenizer::{BertTokenizer, FNetTokenizer};
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"), //! local_path: PathBuf::from("path/to/spiece.model"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?; //! let weights_path = weights_resource.get_local_path()?;

View File

@ -19,22 +19,22 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config}; //! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::Gpt2Tokenizer; //! use rust_tokenizers::tokenizer::Gpt2Tokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let merges_resource = Resource::Local(LocalResource { //! let merges_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?; //! let merges_path = merges_resource.get_local_path()?;

View File

@ -22,20 +22,20 @@
//! }; //! };
//! use rust_bert::pipelines::common::ModelType; //! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; //! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
//! use rust_bert::resources::{RemoteResource, Resource}; //! use rust_bert::resources::RemoteResource;
//! use tch::Device; //! use tch::Device;
//! //!
//! fn main() -> anyhow::Result<()> { //! fn main() -> anyhow::Result<()> {
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained( //! let config_resource = Box::new(RemoteResource::from_pretrained(
//! GptNeoConfigResources::GPT_NEO_1_3B, //! GptNeoConfigResources::GPT_NEO_1_3B,
//! )); //! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( //! let vocab_resource = Box::new(RemoteResource::from_pretrained(
//! GptNeoVocabResources::GPT_NEO_1_3B, //! GptNeoVocabResources::GPT_NEO_1_3B,
//! )); //! ));
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained( //! let merges_resource = Box::new(RemoteResource::from_pretrained(
//! GptNeoMergesResources::GPT_NEO_1_3B, //! GptNeoMergesResources::GPT_NEO_1_3B,
//! )); //! ));
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained( //! let model_resource = Box::new(RemoteResource::from_pretrained(
//! GptNeoModelResources::GPT_NEO_1_3B, //! GptNeoModelResources::GPT_NEO_1_3B,
//! )); //! ));
//! //!

View File

@ -27,24 +27,24 @@
//! use rust_bert::pipelines::question_answering::{ //! use rust_bert::pipelines::question_answering::{
//! QaInput, QuestionAnsweringConfig, QuestionAnsweringModel, //! QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
//! }; //! };
//! use rust_bert::resources::{RemoteResource, Resource}; //! use rust_bert::resources::{RemoteResource};
//! //!
//! fn main() -> anyhow::Result<()> { //! fn main() -> anyhow::Result<()> {
//! // Set-up Question Answering model //! // Set-up Question Answering model
//! let config = QuestionAnsweringConfig::new( //! let config = QuestionAnsweringConfig::new(
//! ModelType::Longformer, //! ModelType::Longformer,
//! Resource::Remote(RemoteResource::from_pretrained( //! RemoteResource::from_pretrained(
//! LongformerModelResources::LONGFORMER_BASE_SQUAD1, //! LongformerModelResources::LONGFORMER_BASE_SQUAD1,
//! )), //! ),
//! Resource::Remote(RemoteResource::from_pretrained( //! RemoteResource::from_pretrained(
//! LongformerConfigResources::LONGFORMER_BASE_SQUAD1, //! LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
//! )), //! ),
//! Resource::Remote(RemoteResource::from_pretrained( //! RemoteResource::from_pretrained(
//! LongformerVocabResources::LONGFORMER_BASE_SQUAD1, //! LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
//! )), //! ),
//! Some(Resource::Remote(RemoteResource::from_pretrained( //! Some(RemoteResource::from_pretrained(
//! LongformerMergesResources::LONGFORMER_BASE_SQUAD1, //! LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
//! ))), //! )),
//! false, //! false,
//! None, //! None,
//! false, //! false,

View File

@ -10,9 +10,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::m2m_100::decoder::M2M100Decoder; use crate::m2m_100::decoder::M2M100Decoder;
use crate::m2m_100::encoder::M2M100Encoder; use crate::m2m_100::encoder::M2M100Encoder;
use crate::m2m_100::LayerState; use crate::m2m_100::LayerState;
@ -25,7 +22,6 @@ use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
}; };
use crate::pipelines::translation::Language; use crate::pipelines::translation::Language;
use crate::resources::{RemoteResource, Resource};
use crate::{Config, RustBertError}; use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{M2M100Vocab, Vocab}; use rust_tokenizers::vocab::{M2M100Vocab, Vocab};
@ -618,51 +614,10 @@ impl M2M100Generator {
/// # } /// # }
/// ``` /// ```
pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models let config_path = generate_config.config_resource.get_local_path()?;
let model_resource = if generate_config.model_resource let vocab_path = generate_config.vocab_resource.get_local_path()?;
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) let merges_path = generate_config.merges_resource.get_local_path()?;
{ let weights_path = generate_config.model_resource.get_local_path()?;
Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
))
} else {
generate_config.merges_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device; let device = generate_config.device;
generate_config.validate(); generate_config.validate();
@ -889,7 +844,7 @@ mod test {
use tch::Device; use tch::Device;
use crate::{ use crate::{
resources::{RemoteResource, Resource}, resources::{RemoteResource, ResourceProvider},
Config, Config,
}; };
@ -898,7 +853,7 @@ mod test {
#[test] #[test]
#[ignore] // compilation is enough, no need to run #[ignore] // compilation is enough, no need to run
fn mbart_model_send() { fn mbart_model_send() {
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M, M2M100ConfigResources::M2M100_418M,
)); ));
let config_path = config_resource.get_local_path().expect(""); let config_path = config_resource.get_local_path().expect("");

View File

@ -20,22 +20,22 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::m2m_100::{M2M100Config, M2M100Model}; //! use rust_bert::m2m_100::{M2M100Config, M2M100Model};
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::M2M100Tokenizer; //! use rust_tokenizers::tokenizer::M2M100Tokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let merges_resource = Resource::Local(LocalResource { //! let merges_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"), //! local_path: PathBuf::from("path/to/spiece.model"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?; //! let merges_path = merges_resource.get_local_path()?;

View File

@ -21,22 +21,22 @@
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::bart::{BartConfig, BartModel}; //! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::marian::MarianForConditionalGeneration; //! use rust_bert::marian::MarianForConditionalGeneration;
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::MarianTokenizer; //! use rust_tokenizers::tokenizer::MarianTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.json"), //! local_path: PathBuf::from("path/to/vocab.json"),
//! }); //! };
//! let sentence_piece_resource = Resource::Local(LocalResource { //! let sentence_piece_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"), //! local_path: PathBuf::from("path/to/spiece.model"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let spiece_path = sentence_piece_resource.get_local_path()?; //! let spiece_path = sentence_piece_resource.get_local_path()?;

View File

@ -12,7 +12,6 @@
use crate::bart::BartModelOutput; use crate::bart::BartModelOutput;
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::mbart::decoder::MBartDecoder; use crate::mbart::decoder::MBartDecoder;
use crate::mbart::encoder::MBartEncoder; use crate::mbart::encoder::MBartEncoder;
use crate::mbart::LayerState; use crate::mbart::LayerState;
@ -24,7 +23,6 @@ use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
}; };
use crate::pipelines::translation::Language; use crate::pipelines::translation::Language;
use crate::resources::{RemoteResource, Resource};
use crate::{Activation, Config, RustBertError}; use crate::{Activation, Config, RustBertError};
use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{MBart50Vocab, Vocab}; use rust_tokenizers::vocab::{MBart50Vocab, Vocab};
@ -839,40 +837,9 @@ impl MBartGenerator {
/// # } /// # }
/// ``` /// ```
pub fn new(generate_config: GenerateConfig) -> Result<MBartGenerator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<MBartGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models let config_path = generate_config.config_resource.get_local_path()?;
let model_resource = if generate_config.model_resource let vocab_path = generate_config.vocab_resource.get_local_path()?;
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) let weights_path = generate_config.model_resource.get_local_path()?;
{
Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device; let device = generate_config.device;
generate_config.validate(); generate_config.validate();
@ -1099,7 +1066,7 @@ mod test {
use tch::Device; use tch::Device;
use crate::{ use crate::{
resources::{RemoteResource, Resource}, resources::{RemoteResource, ResourceProvider},
Config, Config,
}; };
@ -1108,7 +1075,7 @@ mod test {
#[test] #[test]
#[ignore] // compilation is enough, no need to run #[ignore] // compilation is enough, no need to run
fn mbart_model_send() { fn mbart_model_send() {
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY, MBartConfigResources::MBART50_MANY_TO_MANY,
)); ));
let config_path = config_resource.get_local_path().expect(""); let config_path = config_resource.get_local_path().expect("");

View File

@ -19,19 +19,19 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::mbart::{MBartConfig, MBartModel}; //! use rust_bert::mbart::{MBartConfig, MBartModel};
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::MBart50Tokenizer; //! use rust_tokenizers::tokenizer::MBart50Tokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?; //! let weights_path = weights_resource.get_local_path()?;

View File

@ -24,19 +24,19 @@
//! MobileBertConfig, MobileBertConfigResources, MobileBertForMaskedLM, //! MobileBertConfig, MobileBertConfigResources, MobileBertForMaskedLM,
//! MobileBertModelResources, MobileBertVocabResources, //! MobileBertModelResources, MobileBertVocabResources,
//! }; //! };
//! use rust_bert::resources::{RemoteResource, Resource}; //! use rust_bert::resources::{RemoteResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer; //! use rust_tokenizers::tokenizer::BertTokenizer;
//! //!
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained( //! let config_resource = RemoteResource::from_pretrained(
//! MobileBertConfigResources::MOBILEBERT_UNCASED, //! MobileBertConfigResources::MOBILEBERT_UNCASED,
//! )); //! );
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( //! let vocab_resource = RemoteResource::from_pretrained(
//! MobileBertVocabResources::MOBILEBERT_UNCASED, //! MobileBertVocabResources::MOBILEBERT_UNCASED,
//! )); //! );
//! let weights_resource = Resource::Remote(RemoteResource::from_pretrained( //! let weights_resource = RemoteResource::from_pretrained(
//! MobileBertModelResources::MOBILEBERT_UNCASED, //! MobileBertModelResources::MOBILEBERT_UNCASED,
//! )); //! );
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?; //! let weights_path = weights_resource.get_local_path()?;

View File

@ -18,22 +18,22 @@
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::gpt2::Gpt2Config; //! use rust_bert::gpt2::Gpt2Config;
//! use rust_bert::openai_gpt::OpenAiGptModel; //! use rust_bert::openai_gpt::OpenAiGptModel;
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::OpenAiGptTokenizer; //! use rust_tokenizers::tokenizer::OpenAiGptTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let merges_resource = Resource::Local(LocalResource { //! let merges_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?; //! let merges_path = merges_resource.get_local_path()?;

View File

@ -15,10 +15,7 @@
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair; use crate::common::embeddings::process_ids_embeddings_pair;
use crate::common::linear::{linear_no_bias, LinearNoBias}; use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::common::resources::{RemoteResource, Resource}; use crate::gpt2::Gpt2Config;
use crate::gpt2::{
Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::openai_gpt::transformer::Block; use crate::openai_gpt::transformer::Block;
use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator; use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
@ -471,51 +468,10 @@ impl OpenAIGenerator {
pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> {
generate_config.validate(); generate_config.validate();
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models let config_path = generate_config.config_resource.get_local_path()?;
let model_resource = if generate_config.model_resource let vocab_path = generate_config.vocab_resource.get_local_path()?;
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) let merges_path = generate_config.merges_resource.get_local_path()?;
{ let weights_path = generate_config.model_resource.get_local_path()?;
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
))
} else {
generate_config.merges_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device; let device = generate_config.device;
let mut var_store = nn::VarStore::new(device); let mut var_store = nn::VarStore::new(device);

View File

@ -19,19 +19,19 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::pegasus::{PegasusConfig, PegasusModel}; //! use rust_bert::pegasus::{PegasusConfig, PegasusModel};
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::PegasusTokenizer; //! use rust_tokenizers::tokenizer::PegasusTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"), //! local_path: PathBuf::from("path/to/spiece.model"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?; //! let weights_path = weights_resource.get_local_path()?;

View File

@ -12,8 +12,6 @@
use crate::bart::BartModelOutput; use crate::bart::BartModelOutput;
use crate::common::kind::get_negative_infinity; use crate::common::kind::get_negative_infinity;
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::mbart::MBartConfig; use crate::mbart::MBartConfig;
use crate::pegasus::decoder::PegasusDecoder; use crate::pegasus::decoder::PegasusDecoder;
use crate::pegasus::encoder::PegasusEncoder; use crate::pegasus::encoder::PegasusEncoder;
@ -601,40 +599,9 @@ impl PegasusConditionalGenerator {
pub fn new( pub fn new(
generate_config: GenerateConfig, generate_config: GenerateConfig,
) -> Result<PegasusConditionalGenerator, RustBertError> { ) -> Result<PegasusConditionalGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models let config_path = generate_config.config_resource.get_local_path()?;
let model_resource = if generate_config.model_resource let vocab_path = generate_config.vocab_resource.get_local_path()?;
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) let weights_path = generate_config.model_resource.get_local_path()?;
{
Resource::Remote(RemoteResource::from_pretrained(
PegasusModelResources::CNN_DAILYMAIL,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
PegasusConfigResources::CNN_DAILYMAIL,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
PegasusVocabResources::CNN_DAILYMAIL,
))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device; let device = generate_config.device;
generate_config.validate(); generate_config.validate();

View File

@ -45,17 +45,21 @@
//! The authors of this repository are not responsible for any generation //! The authors of this repository are not responsible for any generation
//! from the 3rd party utilization of the pretrained system. //! from the 3rd party utilization of the pretrained system.
use crate::common::error::RustBertError; use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource}; use crate::gpt2::GPT2Generator;
use crate::gpt2::{
GPT2Generator, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator; use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::resources::ResourceProvider;
use std::collections::HashMap; use std::collections::HashMap;
use tch::{Device, Kind, Tensor}; use tch::{Device, Kind, Tensor};
use uuid::Uuid; use uuid::Uuid;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
resources::RemoteResource,
};
/// # Configuration for multi-turn classification /// # Configuration for multi-turn classification
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a /// Contains information regarding the model to load, mirrors the GenerationConfig, with a
/// different set of default parameters and sets the device to place the model on. /// different set of default parameters and sets the device to place the model on.
@ -63,13 +67,13 @@ pub struct ConversationConfig {
/// Model type /// Model type
pub model_type: ModelType, pub model_type: ModelType,
/// Model weights resource (default: DialoGPT-medium) /// Model weights resource (default: DialoGPT-medium)
pub model_resource: Resource, pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: DialoGPT-medium) /// Config resource (default: DialoGPT-medium)
pub config_resource: Resource, pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: DialoGPT-medium) /// Vocab resource (default: DialoGPT-medium)
pub vocab_resource: Resource, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: DialoGPT-medium) /// Merges resource (default: DialoGPT-medium)
pub merges_resource: Resource, pub merges_resource: Box<dyn ResourceProvider + Send>,
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
@ -104,20 +108,21 @@ pub struct ConversationConfig {
pub device: Device, pub device: Device,
} }
#[cfg(feature = "remote")]
impl Default for ConversationConfig { impl Default for ConversationConfig {
fn default() -> ConversationConfig { fn default() -> ConversationConfig {
ConversationConfig { ConversationConfig {
model_type: ModelType::GPT2, model_type: ModelType::GPT2,
model_resource: Resource::Remote(RemoteResource::from_pretrained( model_resource: Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::DIALOGPT_MEDIUM, Gpt2ModelResources::DIALOGPT_MEDIUM,
)), )),
config_resource: Resource::Remote(RemoteResource::from_pretrained( config_resource: Box::new(RemoteResource::from_pretrained(
Gpt2ConfigResources::DIALOGPT_MEDIUM, Gpt2ConfigResources::DIALOGPT_MEDIUM,
)), )),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained( vocab_resource: Box::new(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM, Gpt2VocabResources::DIALOGPT_MEDIUM,
)), )),
merges_resource: Resource::Remote(RemoteResource::from_pretrained( merges_resource: Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM, Gpt2MergesResources::DIALOGPT_MEDIUM,
)), )),
min_length: 0, min_length: 0,

View File

@ -73,10 +73,7 @@ use tch::{no_grad, Device, Tensor};
use crate::bart::LayerState as BartLayerState; use crate::bart::LayerState as BartLayerState;
use crate::common::error::RustBertError; use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource}; use crate::common::resources::ResourceProvider;
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::gpt_neo::LayerState as GPTNeoLayerState; use crate::gpt_neo::LayerState as GPTNeoLayerState;
use crate::pipelines::generation_utils::private_generation_utils::{ use crate::pipelines::generation_utils::private_generation_utils::{
InternalGenerateOptions, PrivateLanguageGenerator, InternalGenerateOptions, PrivateLanguageGenerator,
@ -89,18 +86,24 @@ use crate::xlnet::LayerState as XLNetLayerState;
use self::ordered_float::OrderedFloat; use self::ordered_float::OrderedFloat;
use crate::pipelines::common::TokenizerOption; use crate::pipelines::common::TokenizerOption;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
resources::RemoteResource,
};
extern crate ordered_float; extern crate ordered_float;
/// # Configuration for text generation /// # Configuration for text generation
pub struct GenerateConfig { pub struct GenerateConfig {
/// Model weights resource (default: pretrained GPT2 model) /// Model weights resource (default: pretrained GPT2 model)
pub model_resource: Resource, pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained GPT2 model) /// Config resource (default: pretrained GPT2 model)
pub config_resource: Resource, pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained GPT2 model) /// Vocab resource (default: pretrained GPT2 model)
pub vocab_resource: Resource, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained GPT2 model) /// Merges resource (default: pretrained GPT2 model)
pub merges_resource: Resource, pub merges_resource: Box<dyn ResourceProvider + Send>,
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
@ -133,21 +136,14 @@ pub struct GenerateConfig {
pub device: Device, pub device: Device,
} }
#[cfg(feature = "remote")]
impl Default for GenerateConfig { impl Default for GenerateConfig {
fn default() -> GenerateConfig { fn default() -> GenerateConfig {
GenerateConfig { GenerateConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained( model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
Gpt2ModelResources::GPT2, config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
)), vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
config_resource: Resource::Remote(RemoteResource::from_pretrained( merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
Gpt2ConfigResources::GPT2,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::GPT2,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
)),
min_length: 0, min_length: 0,
max_length: 20, max_length: 20,
do_sample: true, do_sample: true,

View File

@ -78,7 +78,7 @@
//! use rust_bert::pipelines::common::ModelType; //! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::ner::NERModel; //! use rust_bert::pipelines::ner::NERModel;
//! use rust_bert::pipelines::token_classification::TokenClassificationConfig; //! use rust_bert::pipelines::token_classification::TokenClassificationConfig;
//! use rust_bert::resources::{RemoteResource, Resource}; //! use rust_bert::resources::RemoteResource;
//! use rust_bert::roberta::{ //! use rust_bert::roberta::{
//! RobertaConfigResources, RobertaModelResources, RobertaVocabResources, //! RobertaConfigResources, RobertaModelResources, RobertaVocabResources,
//! }; //! };
@ -87,13 +87,13 @@
//! # fn main() -> anyhow::Result<()> { //! # fn main() -> anyhow::Result<()> {
//! let ner_config = TokenClassificationConfig { //! let ner_config = TokenClassificationConfig {
//! model_type: ModelType::XLMRoberta, //! model_type: ModelType::XLMRoberta,
//! model_resource: Resource::Remote(RemoteResource::from_pretrained( //! model_resource: Box::new(RemoteResource::from_pretrained(
//! RobertaModelResources::XLM_ROBERTA_NER_DE, //! RobertaModelResources::XLM_ROBERTA_NER_DE,
//! )), //! )),
//! config_resource: Resource::Remote(RemoteResource::from_pretrained( //! config_resource: Box::new(RemoteResource::from_pretrained(
//! RobertaConfigResources::XLM_ROBERTA_NER_DE, //! RobertaConfigResources::XLM_ROBERTA_NER_DE,
//! )), //! )),
//! vocab_resource: Resource::Remote(RemoteResource::from_pretrained( //! vocab_resource: Box::new(RemoteResource::from_pretrained(
//! RobertaVocabResources::XLM_ROBERTA_NER_DE, //! RobertaVocabResources::XLM_ROBERTA_NER_DE,
//! )), //! )),
//! lower_case: false, //! lower_case: false,

View File

@ -82,16 +82,20 @@
//! To run the pipeline for another language, change the POSModel configuration from its default (see the NER pipeline for an illustration). //! To run the pipeline for another language, change the POSModel configuration from its default (see the NER pipeline for an illustration).
use crate::common::error::RustBertError; use crate::common::error::RustBertError;
use crate::mobilebert::{ use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
MobileBertConfigResources, MobileBertModelResources, MobileBertVocabResources,
};
use crate::pipelines::common::ModelType;
use crate::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
};
use crate::resources::{RemoteResource, Resource};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tch::Device;
#[cfg(feature = "remote")]
use {
crate::{
mobilebert::{
MobileBertConfigResources, MobileBertModelResources, MobileBertVocabResources,
},
pipelines::{common::ModelType, token_classification::LabelAggregationOption},
resources::RemoteResource,
},
tch::Device,
};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
/// # Part of Speech tag /// # Part of Speech tag
@ -109,19 +113,20 @@ pub struct POSConfig {
token_classification_config: TokenClassificationConfig, token_classification_config: TokenClassificationConfig,
} }
#[cfg(feature = "remote")]
impl Default for POSConfig { impl Default for POSConfig {
/// Provides a Part of speech tagging model (English) /// Provides a Part of speech tagging model (English)
fn default() -> POSConfig { fn default() -> POSConfig {
POSConfig { POSConfig {
token_classification_config: TokenClassificationConfig { token_classification_config: TokenClassificationConfig {
model_type: ModelType::MobileBert, model_type: ModelType::MobileBert,
model_resource: Resource::Remote(RemoteResource::from_pretrained( model_resource: Box::new(RemoteResource::from_pretrained(
MobileBertModelResources::MOBILEBERT_ENGLISH_POS, MobileBertModelResources::MOBILEBERT_ENGLISH_POS,
)), )),
config_resource: Resource::Remote(RemoteResource::from_pretrained( config_resource: Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_ENGLISH_POS, MobileBertConfigResources::MOBILEBERT_ENGLISH_POS,
)), )),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained( vocab_resource: Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_ENGLISH_POS, MobileBertVocabResources::MOBILEBERT_ENGLISH_POS,
)), )),
merges_resource: None, merges_resource: None,

View File

@ -46,17 +46,14 @@
use crate::albert::AlbertForQuestionAnswering; use crate::albert::AlbertForQuestionAnswering;
use crate::bert::BertForQuestionAnswering; use crate::bert::BertForQuestionAnswering;
use crate::common::error::RustBertError; use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::deberta::DebertaForQuestionAnswering; use crate::deberta::DebertaForQuestionAnswering;
use crate::distilbert::{ use crate::distilbert::DistilBertForQuestionAnswering;
DistilBertConfigResources, DistilBertForQuestionAnswering, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::fnet::FNetForQuestionAnswering; use crate::fnet::FNetForQuestionAnswering;
use crate::longformer::LongformerForQuestionAnswering; use crate::longformer::LongformerForQuestionAnswering;
use crate::mobilebert::MobileBertForQuestionAnswering; use crate::mobilebert::MobileBertForQuestionAnswering;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption}; use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::reformer::ReformerForQuestionAnswering; use crate::reformer::ReformerForQuestionAnswering;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForQuestionAnswering; use crate::roberta::RobertaForQuestionAnswering;
use crate::xlnet::XLNetForQuestionAnswering; use crate::xlnet::XLNetForQuestionAnswering;
use rust_tokenizers::{Offset, TokenIdsWithOffsets, TokenizedInput}; use rust_tokenizers::{Offset, TokenIdsWithOffsets, TokenizedInput};
@ -70,6 +67,12 @@ use tch::kind::Kind::Float;
use tch::nn::VarStore; use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Tensor}; use tch::{nn, no_grad, Device, Tensor};
#[cfg(feature = "remote")]
use crate::{
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
resources::RemoteResource,
};
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
/// # Input for Question Answering /// # Input for Question Answering
/// Includes a context (containing the answer) and question strings /// Includes a context (containing the answer) and question strings
@ -124,13 +127,13 @@ fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
/// Contains information regarding the model to load and device to place the model on. /// Contains information regarding the model to load and device to place the model on.
pub struct QuestionAnsweringConfig { pub struct QuestionAnsweringConfig {
/// Model weights resource (default: pretrained DistilBERT model on SQuAD) /// Model weights resource (default: pretrained DistilBERT model on SQuAD)
pub model_resource: Resource, pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained DistilBERT model on SQuAD) /// Config resource (default: pretrained DistilBERT model on SQuAD)
pub config_resource: Resource, pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained DistilBERT model on SQuAD) /// Vocab resource (default: pretrained DistilBERT model on SQuAD)
pub vocab_resource: Resource, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: None) /// Merges resource (default: None)
pub merges_resource: Option<Resource>, pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Device to place the model on (default: CUDA/GPU when available) /// Device to place the model on (default: CUDA/GPU when available)
pub device: Device, pub device: Device,
/// Model type /// Model type
@ -157,27 +160,30 @@ impl QuestionAnsweringConfig {
/// # Arguments /// # Arguments
/// ///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!) /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot) /// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json) /// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. /// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model) /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new( pub fn new<R>(
model_type: ModelType, model_type: ModelType,
model_resource: Resource, model_resource: R,
config_resource: Resource, config_resource: R,
vocab_resource: Resource, vocab_resource: R,
merges_resource: Option<Resource>, merges_resource: Option<R>,
lower_case: bool, lower_case: bool,
strip_accents: impl Into<Option<bool>>, strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>, add_prefix_space: impl Into<Option<bool>>,
) -> QuestionAnsweringConfig { ) -> QuestionAnsweringConfig
where
R: ResourceProvider + Send + 'static,
{
QuestionAnsweringConfig { QuestionAnsweringConfig {
model_type, model_type,
model_resource, model_resource: Box::new(model_resource),
config_resource, config_resource: Box::new(config_resource),
vocab_resource, vocab_resource: Box::new(vocab_resource),
merges_resource, merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case, lower_case,
strip_accents: strip_accents.into(), strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(), add_prefix_space: add_prefix_space.into(),
@ -194,21 +200,21 @@ impl QuestionAnsweringConfig {
/// # Arguments /// # Arguments
/// ///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!) /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot) /// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json) /// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. /// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model) /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
/// * max_seq_length - Optional maximum sequence token length to limit memory footprint. If the context is too long, it will be processed with sliding windows. Defaults to 384. /// * max_seq_length - Optional maximum sequence token length to limit memory footprint. If the context is too long, it will be processed with sliding windows. Defaults to 384.
/// * max_query_length - Optional maximum question token length. Defaults to 64. /// * max_query_length - Optional maximum question token length. Defaults to 64.
/// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128. /// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128.
/// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15. /// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15.
pub fn custom_new( pub fn custom_new<R>(
model_type: ModelType, model_type: ModelType,
model_resource: Resource, model_resource: R,
config_resource: Resource, config_resource: R,
vocab_resource: Resource, vocab_resource: R,
merges_resource: Option<Resource>, merges_resource: Option<R>,
lower_case: bool, lower_case: bool,
strip_accents: impl Into<Option<bool>>, strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>, add_prefix_space: impl Into<Option<bool>>,
@ -216,13 +222,16 @@ impl QuestionAnsweringConfig {
doc_stride: impl Into<Option<usize>>, doc_stride: impl Into<Option<usize>>,
max_query_length: impl Into<Option<usize>>, max_query_length: impl Into<Option<usize>>,
max_answer_length: impl Into<Option<usize>>, max_answer_length: impl Into<Option<usize>>,
) -> QuestionAnsweringConfig { ) -> QuestionAnsweringConfig
where
R: ResourceProvider + Send + 'static,
{
QuestionAnsweringConfig { QuestionAnsweringConfig {
model_type, model_type,
model_resource, model_resource: Box::new(model_resource),
config_resource, config_resource: Box::new(config_resource),
vocab_resource, vocab_resource: Box::new(vocab_resource),
merges_resource, merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case, lower_case,
strip_accents: strip_accents.into(), strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(), add_prefix_space: add_prefix_space.into(),
@ -235,16 +244,17 @@ impl QuestionAnsweringConfig {
} }
} }
#[cfg(feature = "remote")]
impl Default for QuestionAnsweringConfig { impl Default for QuestionAnsweringConfig {
fn default() -> QuestionAnsweringConfig { fn default() -> QuestionAnsweringConfig {
QuestionAnsweringConfig { QuestionAnsweringConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained( model_resource: Box::new(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SQUAD, DistilBertModelResources::DISTIL_BERT_SQUAD,
)), )),
config_resource: Resource::Remote(RemoteResource::from_pretrained( config_resource: Box::new(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD, DistilBertConfigResources::DISTIL_BERT_SQUAD,
)), )),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained( vocab_resource: Box::new(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD, DistilBertVocabResources::DISTIL_BERT_SQUAD,
)), )),
merges_resource: None, merges_resource: None,

View File

@ -15,7 +15,7 @@
//! //!
//! ```no_run //! ```no_run
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig; //! use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
//! use rust_bert::resources::{RemoteResource, Resource}; //! use rust_bert::resources::{RemoteResource};
//! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources}; //! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources};
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel; //! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
//! use rust_bert::pipelines::common::ModelType; //! use rust_bert::pipelines::common::ModelType;
@ -23,9 +23,9 @@
//! //!
//! //Load a configuration //! //Load a configuration
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert, //! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
//! Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2)), //! RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2),
//! Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2)), //! RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
//! Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2)), //! RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
//! None, //merges resource only relevant with ModelType::Roberta //! None, //merges resource only relevant with ModelType::Roberta
//! true, //lowercase //! true, //lowercase
//! None, //strip_accents //! None, //strip_accents
@ -61,17 +61,14 @@ use crate::albert::AlbertForSequenceClassification;
use crate::bart::BartForSequenceClassification; use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification; use crate::bert::BertForSequenceClassification;
use crate::common::error::RustBertError; use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::deberta::DebertaForSequenceClassification; use crate::deberta::DebertaForSequenceClassification;
use crate::distilbert::{ use crate::distilbert::DistilBertModelClassifier;
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::fnet::FNetForSequenceClassification; use crate::fnet::FNetForSequenceClassification;
use crate::longformer::LongformerForSequenceClassification; use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification; use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption}; use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::reformer::ReformerForSequenceClassification; use crate::reformer::ReformerForSequenceClassification;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForSequenceClassification; use crate::roberta::RobertaForSequenceClassification;
use crate::xlnet::XLNetForSequenceClassification; use crate::xlnet::XLNetForSequenceClassification;
use rust_tokenizers::tokenizer::TruncationStrategy; use rust_tokenizers::tokenizer::TruncationStrategy;
@ -82,6 +79,12 @@ use std::collections::HashMap;
use tch::nn::VarStore; use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Kind, Tensor}; use tch::{nn, no_grad, Device, Kind, Tensor};
#[cfg(feature = "remote")]
use crate::{
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
resources::RemoteResource,
};
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
/// # Label generated by a `SequenceClassificationModel` /// # Label generated by a `SequenceClassificationModel`
pub struct Label { pub struct Label {
@ -102,13 +105,13 @@ pub struct SequenceClassificationConfig {
/// Model type /// Model type
pub model_type: ModelType, pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model on CoNLL) /// Model weights resource (default: pretrained BERT model on CoNLL)
pub model_resource: Resource, pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BERT model on CoNLL) /// Config resource (default: pretrained BERT model on CoNLL)
pub config_resource: Resource, pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model on CoNLL) /// Vocab resource (default: pretrained BERT model on CoNLL)
pub vocab_resource: Resource, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: None) /// Merges resource (default: None)
pub merges_resource: Option<Resource>, pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Automatically lower case all input upon tokenization (assumes a lower-cased model) /// Automatically lower case all input upon tokenization (assumes a lower-cased model)
pub lower_case: bool, pub lower_case: bool,
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models /// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
@ -125,27 +128,30 @@ impl SequenceClassificationConfig {
/// # Arguments /// # Arguments
/// ///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!) /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model - The `Resource` pointing to the model to load (e.g. model.ot) /// * model - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config - The `Resource' pointing to the model configuration to load (e.g. config.json) /// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. /// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model) /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new( pub fn new<R>(
model_type: ModelType, model_type: ModelType,
model_resource: Resource, model_resource: R,
config_resource: Resource, config_resource: R,
vocab_resource: Resource, vocab_resource: R,
merges_resource: Option<Resource>, merges_resource: Option<R>,
lower_case: bool, lower_case: bool,
strip_accents: impl Into<Option<bool>>, strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>, add_prefix_space: impl Into<Option<bool>>,
) -> SequenceClassificationConfig { ) -> SequenceClassificationConfig
where
R: ResourceProvider + Send + 'static,
{
SequenceClassificationConfig { SequenceClassificationConfig {
model_type, model_type,
model_resource, model_resource: Box::new(model_resource),
config_resource, config_resource: Box::new(config_resource),
vocab_resource, vocab_resource: Box::new(vocab_resource),
merges_resource, merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case, lower_case,
strip_accents: strip_accents.into(), strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(), add_prefix_space: add_prefix_space.into(),
@ -154,26 +160,20 @@ impl SequenceClassificationConfig {
} }
} }
#[cfg(feature = "remote")]
impl Default for SequenceClassificationConfig { impl Default for SequenceClassificationConfig {
/// Provides a defaultSST-2 sentiment analysis model (English) /// Provides a defaultSST-2 sentiment analysis model (English)
fn default() -> SequenceClassificationConfig { fn default() -> SequenceClassificationConfig {
SequenceClassificationConfig { SequenceClassificationConfig::new(
model_type: ModelType::DistilBert, ModelType::DistilBert,
model_resource: Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2),
DistilBertModelResources::DISTIL_BERT_SST2, RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
)), RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
config_resource: Resource::Remote(RemoteResource::from_pretrained( None,
DistilBertConfigResources::DISTIL_BERT_SST2, true,
)), None,
vocab_resource: Resource::Remote(RemoteResource::from_pretrained( None,
DistilBertVocabResources::DISTIL_BERT_SST2, )
)),
merges_resource: None,
lower_case: true,
strip_accents: None,
add_prefix_space: None,
device: Device::cuda_if_available(),
}
} }
} }

View File

@ -64,17 +64,21 @@
use tch::Device; use tch::Device;
use crate::bart::{ use crate::bart::BartGenerator;
BartConfigResources, BartGenerator, BartMergesResources, BartModelResources, BartVocabResources,
};
use crate::common::error::RustBertError; use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::pegasus::PegasusConditionalGenerator; use crate::pegasus::PegasusConditionalGenerator;
use crate::pipelines::common::ModelType; use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::prophetnet::ProphetNetConditionalGenerator; use crate::prophetnet::ProphetNetConditionalGenerator;
use crate::resources::ResourceProvider;
use crate::t5::T5Generator; use crate::t5::T5Generator;
#[cfg(feature = "remote")]
use crate::{
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
resources::RemoteResource,
};
/// # Configuration for text summarization /// # Configuration for text summarization
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a /// Contains information regarding the model to load, mirrors the GenerationConfig, with a
/// different set of default parameters and sets the device to place the model on. /// different set of default parameters and sets the device to place the model on.
@ -82,13 +86,13 @@ pub struct SummarizationConfig {
/// Model type /// Model type
pub model_type: ModelType, pub model_type: ModelType,
/// Model weights resource (default: pretrained BART model on CNN-DM) /// Model weights resource (default: pretrained BART model on CNN-DM)
pub model_resource: Resource, pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BART model on CNN-DM) /// Config resource (default: pretrained BART model on CNN-DM)
pub config_resource: Resource, pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BART model on CNN-DM) /// Vocab resource (default: pretrained BART model on CNN-DM)
pub vocab_resource: Resource, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained BART model on CNN-DM) /// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Resource, pub merges_resource: Box<dyn ResourceProvider + Send>,
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
@ -127,45 +131,26 @@ impl SummarizationConfig {
/// # Arguments /// # Arguments
/// ///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!) /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot) /// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json) /// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - The `Resource` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt). /// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new( pub fn new<R>(
model_type: ModelType, model_type: ModelType,
model_resource: Resource, model_resource: R,
config_resource: Resource, config_resource: R,
vocab_resource: Resource, vocab_resource: R,
merges_resource: Resource, merges_resource: R,
) -> SummarizationConfig { ) -> SummarizationConfig
where
R: ResourceProvider + Send + 'static,
{
SummarizationConfig { SummarizationConfig {
model_type, model_type,
model_resource, model_resource: Box::new(model_resource),
config_resource, config_resource: Box::new(config_resource),
vocab_resource, vocab_resource: Box::new(vocab_resource),
merges_resource, merges_resource: Box::new(merges_resource),
device: Device::cuda_if_available(),
..Default::default()
}
}
}
impl Default for SummarizationConfig {
fn default() -> SummarizationConfig {
SummarizationConfig {
model_type: ModelType::Bart,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
BartModelResources::BART_CNN,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::BART_CNN,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
BartVocabResources::BART_CNN,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
)),
min_length: 56, min_length: 56,
max_length: 142, max_length: 142,
do_sample: false, do_sample: false,
@ -185,6 +170,19 @@ impl Default for SummarizationConfig {
} }
} }
#[cfg(feature = "remote")]
impl Default for SummarizationConfig {
fn default() -> SummarizationConfig {
SummarizationConfig::new(
ModelType::Bart,
RemoteResource::from_pretrained(BartModelResources::BART_CNN),
RemoteResource::from_pretrained(BartConfigResources::BART_CNN),
RemoteResource::from_pretrained(BartVocabResources::BART_CNN),
RemoteResource::from_pretrained(BartMergesResources::BART_CNN),
)
}
}
impl From<SummarizationConfig> for GenerateConfig { impl From<SummarizationConfig> for GenerateConfig {
fn from(config: SummarizationConfig) -> GenerateConfig { fn from(config: SummarizationConfig) -> GenerateConfig {
GenerateConfig { GenerateConfig {

View File

@ -34,19 +34,22 @@
use tch::Device; use tch::Device;
use crate::common::error::RustBertError; use crate::common::error::RustBertError;
use crate::common::resources::RemoteResource; use crate::gpt2::GPT2Generator;
use crate::gpt2::{
GPT2Generator, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::gpt_neo::GptNeoGenerator; use crate::gpt_neo::GptNeoGenerator;
use crate::openai_gpt::OpenAIGenerator; use crate::openai_gpt::OpenAIGenerator;
use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator; use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator}; use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
use crate::reformer::ReformerGenerator; use crate::reformer::ReformerGenerator;
use crate::resources::Resource; use crate::resources::ResourceProvider;
use crate::xlnet::XLNetGenerator; use crate::xlnet::XLNetGenerator;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
resources::RemoteResource,
};
/// # Configuration for text generation /// # Configuration for text generation
/// Contains information regarding the model to load, mirrors the GenerateConfig, with a /// Contains information regarding the model to load, mirrors the GenerateConfig, with a
/// different set of default parameters and sets the device to place the model on. /// different set of default parameters and sets the device to place the model on.
@ -54,13 +57,13 @@ pub struct TextGenerationConfig {
/// Model type /// Model type
pub model_type: ModelType, pub model_type: ModelType,
/// Model weights resource (default: pretrained BART model on CNN-DM) /// Model weights resource (default: pretrained BART model on CNN-DM)
pub model_resource: Resource, pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BART model on CNN-DM) /// Config resource (default: pretrained BART model on CNN-DM)
pub config_resource: Resource, pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BART model on CNN-DM) /// Vocab resource (default: pretrained BART model on CNN-DM)
pub vocab_resource: Resource, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained BART model on CNN-DM) /// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Resource, pub merges_resource: Box<dyn ResourceProvider + Send>,
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
@ -99,45 +102,26 @@ impl TextGenerationConfig {
/// # Arguments /// # Arguments
/// ///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!) /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot) /// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json) /// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - The `Resource` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt). /// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new( pub fn new<R>(
model_type: ModelType, model_type: ModelType,
model_resource: Resource, model_resource: R,
config_resource: Resource, config_resource: R,
vocab_resource: Resource, vocab_resource: R,
merges_resource: Resource, merges_resource: R,
) -> TextGenerationConfig { ) -> TextGenerationConfig
where
R: ResourceProvider + Send + 'static,
{
TextGenerationConfig { TextGenerationConfig {
model_type, model_type,
model_resource, model_resource: Box::new(model_resource),
config_resource, config_resource: Box::new(config_resource),
vocab_resource, vocab_resource: Box::new(vocab_resource),
merges_resource, merges_resource: Box::new(merges_resource),
device: Device::cuda_if_available(),
..Default::default()
}
}
}
impl Default for TextGenerationConfig {
fn default() -> TextGenerationConfig {
TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::GPT2_MEDIUM,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::GPT2_MEDIUM,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::GPT2_MEDIUM,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2_MEDIUM,
)),
min_length: 0, min_length: 0,
max_length: 20, max_length: 20,
do_sample: true, do_sample: true,
@ -157,6 +141,19 @@ impl Default for TextGenerationConfig {
} }
} }
#[cfg(feature = "remote")]
impl Default for TextGenerationConfig {
fn default() -> TextGenerationConfig {
TextGenerationConfig::new(
ModelType::GPT2,
RemoteResource::from_pretrained(Gpt2ModelResources::GPT2_MEDIUM),
RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM),
RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM),
RemoteResource::from_pretrained(Gpt2MergesResources::GPT2_MEDIUM),
)
}
}
impl From<TextGenerationConfig> for GenerateConfig { impl From<TextGenerationConfig> for GenerateConfig {
fn from(config: TextGenerationConfig) -> GenerateConfig { fn from(config: TextGenerationConfig) -> GenerateConfig {
GenerateConfig { GenerateConfig {

View File

@ -16,17 +16,18 @@
//! //!
//! ```no_run //! ```no_run
//! use rust_bert::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig}; //! use rust_bert::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
//! use rust_bert::resources::{Resource,RemoteResource}; //! use rust_bert::resources::RemoteResource;
//! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources}; //! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
//! use rust_bert::pipelines::common::ModelType; //! use rust_bert::pipelines::common::ModelType;
//! # fn main() -> anyhow::Result<()> { //! # fn main() -> anyhow::Result<()> {
//! //!
//! //Load a configuration //! //Load a configuration
//! use rust_bert::pipelines::token_classification::LabelAggregationOption; //! use rust_bert::pipelines::token_classification::LabelAggregationOption;
//! let config = TokenClassificationConfig::new(ModelType::Bert, //! let config = TokenClassificationConfig::new(
//! Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)), //! ModelType::Bert,
//! Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)), //! RemoteResource::from_pretrained(BertModelResources::BERT_NER),
//! Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)), //! RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
//! RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
//! None, //merges resource only relevant with ModelType::Roberta //! None, //merges resource only relevant with ModelType::Roberta
//! false, //lowercase //! false, //lowercase
//! None, //strip_accents //! None, //strip_accents
@ -111,11 +112,8 @@
//! ``` //! ```
use crate::albert::AlbertForTokenClassification; use crate::albert::AlbertForTokenClassification;
use crate::bert::{ use crate::bert::BertForTokenClassification;
BertConfigResources, BertForTokenClassification, BertModelResources, BertVocabResources,
};
use crate::common::error::RustBertError; use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::deberta::DebertaForTokenClassification; use crate::deberta::DebertaForTokenClassification;
use crate::distilbert::DistilBertForTokenClassification; use crate::distilbert::DistilBertForTokenClassification;
use crate::electra::ElectraForTokenClassification; use crate::electra::ElectraForTokenClassification;
@ -123,6 +121,7 @@ use crate::fnet::FNetForTokenClassification;
use crate::longformer::LongformerForTokenClassification; use crate::longformer::LongformerForTokenClassification;
use crate::mobilebert::MobileBertForTokenClassification; use crate::mobilebert::MobileBertForTokenClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption}; use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForTokenClassification; use crate::roberta::RobertaForTokenClassification;
use crate::xlnet::XLNetForTokenClassification; use crate::xlnet::XLNetForTokenClassification;
use rust_tokenizers::tokenizer::Tokenizer; use rust_tokenizers::tokenizer::Tokenizer;
@ -137,6 +136,12 @@ use std::collections::HashMap;
use tch::nn::VarStore; use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Kind, Tensor}; use tch::{nn, no_grad, Device, Kind, Tensor};
#[cfg(feature = "remote")]
use crate::{
bert::{BertConfigResources, BertModelResources, BertVocabResources},
resources::RemoteResource,
};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
/// # Token generated by a `TokenClassificationModel` /// # Token generated by a `TokenClassificationModel`
pub struct Token { pub struct Token {
@ -215,13 +220,13 @@ pub struct TokenClassificationConfig {
/// Model type /// Model type
pub model_type: ModelType, pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model on CoNLL) /// Model weights resource (default: pretrained BERT model on CoNLL)
pub model_resource: Resource, pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BERT model on CoNLL) /// Config resource (default: pretrained BERT model on CoNLL)
pub config_resource: Resource, pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model on CoNLL) /// Vocab resource (default: pretrained BERT model on CoNLL)
pub vocab_resource: Resource, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained BERT model on CoNLL) /// Merges resource (default: pretrained BERT model on CoNLL)
pub merges_resource: Option<Resource>, pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Automatically lower case all input upon tokenization (assumes a lower-cased model) /// Automatically lower case all input upon tokenization (assumes a lower-cased model)
pub lower_case: bool, pub lower_case: bool,
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models /// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
@ -242,28 +247,31 @@ impl TokenClassificationConfig {
/// # Arguments /// # Arguments
/// ///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!) /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model - The `Resource` pointing to the model to load (e.g. model.ot) /// * model - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config - The `Resource' pointing to the model configuration to load (e.g. config.json) /// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab - The `Resource' pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab - The `ResourceProvider` pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta. /// * vocab - An optional `ResourceProvider` pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model) /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new( pub fn new<R>(
model_type: ModelType, model_type: ModelType,
model_resource: Resource, model_resource: R,
config_resource: Resource, config_resource: R,
vocab_resource: Resource, vocab_resource: R,
merges_resource: Option<Resource>, merges_resource: Option<R>,
lower_case: bool, lower_case: bool,
strip_accents: impl Into<Option<bool>>, strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>, add_prefix_space: impl Into<Option<bool>>,
label_aggregation_function: LabelAggregationOption, label_aggregation_function: LabelAggregationOption,
) -> TokenClassificationConfig { ) -> TokenClassificationConfig
where
R: ResourceProvider + Send + 'static,
{
TokenClassificationConfig { TokenClassificationConfig {
model_type, model_type,
model_resource, model_resource: Box::new(model_resource),
config_resource, config_resource: Box::new(config_resource),
vocab_resource, vocab_resource: Box::new(vocab_resource),
merges_resource, merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case, lower_case,
strip_accents: strip_accents.into(), strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(), add_prefix_space: add_prefix_space.into(),
@ -274,28 +282,21 @@ impl TokenClassificationConfig {
} }
} }
#[cfg(feature = "remote")]
impl Default for TokenClassificationConfig { impl Default for TokenClassificationConfig {
/// Provides a default CoNLL-2003 NER model (English) /// Provides a default CoNLL-2003 NER model (English)
fn default() -> TokenClassificationConfig { fn default() -> TokenClassificationConfig {
TokenClassificationConfig { TokenClassificationConfig::new(
model_type: ModelType::Bert, ModelType::Bert,
model_resource: Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(BertModelResources::BERT_NER),
BertModelResources::BERT_NER, RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
)), RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
config_resource: Resource::Remote(RemoteResource::from_pretrained( None,
BertConfigResources::BERT_NER, false,
)), None,
vocab_resource: Resource::Remote(RemoteResource::from_pretrained( None,
BertVocabResources::BERT_NER, LabelAggregationOption::First,
)), )
merges_resource: None,
lower_case: false,
strip_accents: None,
add_prefix_space: None,
device: Device::cuda_if_available(),
label_aggregation_function: LabelAggregationOption::First,
batch_size: 64,
}
} }
} }

View File

@ -21,22 +21,14 @@
//! }; //! };
//! use rust_bert::pipelines::common::ModelType; //! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; //! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use rust_bert::resources::{RemoteResource, Resource}; //! use rust_bert::resources::RemoteResource;
//! use tch::Device; //! use tch::Device;
//! //!
//! fn main() -> anyhow::Result<()> { //! fn main() -> anyhow::Result<()> {
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained( //! let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
//! M2M100ModelResources::M2M100_418M, //! let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
//! )); //! let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained( //! let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
//! M2M100ConfigResources::M2M100_418M,
//! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
//! M2M100VocabResources::M2M100_418M,
//! ));
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
//! M2M100MergesResources::M2M100_418M,
//! ));
//! //!
//! let source_languages = M2M100SourceLanguages::M2M100_418M; //! let source_languages = M2M100SourceLanguages::M2M100_418M;
//! let target_languages = M2M100TargetLanguages::M2M100_418M; //! let target_languages = M2M100TargetLanguages::M2M100_418M;

View File

@ -1,31 +1,14 @@
use crate::m2m_100::{
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
M2M100TargetLanguages, M2M100VocabResources,
};
use crate::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use crate::mbart::{
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
MBartVocabResources,
};
use crate::pipelines::common::ModelType; use crate::pipelines::common::ModelType;
use crate::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use crate::pipelines::translation::Language;
use crate::resources::{RemoteResource, Resource};
use crate::RustBertError;
use std::fmt::Debug; use std::fmt::Debug;
use tch::Device; use tch::Device;
struct TranslationResources { #[cfg(feature = "remote")]
model_type: ModelType, use crate::{
model_resource: Resource, pipelines::translation::{TranslationConfig, TranslationModel},
config_resource: Resource, resources::ResourceProvider,
vocab_resource: Resource, RustBertError,
merges_resource: Resource, };
source_languages: Vec<Language>,
target_languages: Vec<Language>,
}
#[derive(Clone, Copy, PartialEq)] #[derive(Clone, Copy, PartialEq)]
enum ModelSize { enum ModelSize {
@ -86,21 +69,6 @@ pub struct TranslationModelBuilder {
model_size: Option<ModelSize>, model_size: Option<ModelSize>,
} }
macro_rules! get_marian_resources {
($name:ident) => {
(
(
MarianModelResources::$name,
MarianConfigResources::$name,
MarianVocabResources::$name,
MarianSpmResources::$name,
),
MarianSourceLanguages::$name.iter().cloned().collect(),
MarianTargetLanguages::$name.iter().cloned().collect(),
)
};
}
impl Default for TranslationModelBuilder { impl Default for TranslationModelBuilder {
fn default() -> Self { fn default() -> Self {
TranslationModelBuilder::new() TranslationModelBuilder::new()
@ -335,29 +303,162 @@ impl TranslationModelBuilder {
self self
} }
fn get_default_model( /// Creates the translation model based on the specifications provided
&self, ///
source_languages: Option<&Vec<Language>>, /// # Returns
target_languages: Option<&Vec<Language>>, /// * `TranslationModel` Generated translation model
) -> Result<TranslationResources, RustBertError> { ///
Ok( /// # Example
match self.get_marian_model(source_languages, target_languages) { ///
Ok(marian_resources) => marian_resources, /// ```no_run
Err(_) => match self.model_size { /// use rust_bert::pipelines::translation::Language;
/// use rust_bert::pipelines::translation::TranslationModelBuilder;
/// fn main() -> anyhow::Result<()> {
/// let model = TranslationModelBuilder::new()
/// .with_target_languages([
/// Language::Japanese,
/// Language::Korean,
/// Language::ChineseMandarin,
/// ])
/// .create_model();
/// Ok(())
/// }
/// ```
#[cfg(feature = "remote")]
pub fn create_model(&self) -> Result<TranslationModel, RustBertError> {
let device = self.device.unwrap_or_else(Device::cuda_if_available);
let translation_resources = match (
&self.model_type,
&self.source_languages,
&self.target_languages,
) {
(Some(ModelType::M2M100), source_languages, target_languages) => {
match self.model_size {
Some(value) if value == ModelSize::XLarge => { Some(value) if value == ModelSize::XLarge => {
self.get_m2m100_xlarge_resources(source_languages, target_languages)? model_fetchers::get_m2m100_xlarge_resources(
source_languages.as_ref(),
target_languages.as_ref(),
)?
} }
_ => self.get_m2m100_large_resources(source_languages, target_languages)?, _ => model_fetchers::get_m2m100_large_resources(
}, source_languages.as_ref(),
}, target_languages.as_ref(),
) )?,
}
}
(Some(ModelType::MBart), source_languages, target_languages) => {
model_fetchers::get_mbart50_resources(
source_languages.as_ref(),
target_languages.as_ref(),
)?
}
(Some(ModelType::Marian), source_languages, target_languages) => {
model_fetchers::get_marian_model(
source_languages.as_ref(),
target_languages.as_ref(),
)?
}
(None, source_languages, target_languages) => model_fetchers::get_default_model(
&self.model_size,
source_languages.as_ref(),
target_languages.as_ref(),
)?,
(_, None, None) | (_, _, None) | (_, None, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Source and target languages must be specified for {:?}",
self.model_type.unwrap()
)));
}
(Some(model_type), _, _) => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Automated translation model builder not implemented for {:?}",
model_type
)));
}
};
let translation_config = TranslationConfig::new(
translation_resources.model_type,
translation_resources.model_resource,
translation_resources.config_resource,
translation_resources.vocab_resource,
translation_resources.merges_resource,
translation_resources.source_languages,
translation_resources.target_languages,
device,
);
TranslationModel::new(translation_config)
}
}
#[cfg(feature = "remote")]
mod model_fetchers {
use super::*;
use crate::{
m2m_100::{
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources,
M2M100SourceLanguages, M2M100TargetLanguages, M2M100VocabResources,
},
marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
},
mbart::{
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
MBartVocabResources,
},
resources::RemoteResource,
};
pub(super) struct TranslationResources<R>
where
R: ResourceProvider + Send + 'static,
{
pub(super) model_type: ModelType,
pub(super) model_resource: R,
pub(super) config_resource: R,
pub(super) vocab_resource: R,
pub(super) merges_resource: R,
pub(super) source_languages: Vec<Language>,
pub(super) target_languages: Vec<Language>,
} }
fn get_marian_model( macro_rules! get_marian_resources {
&self, ($name:ident) => {
(
(
MarianModelResources::$name,
MarianConfigResources::$name,
MarianVocabResources::$name,
MarianSpmResources::$name,
),
MarianSourceLanguages::$name.iter().cloned().collect(),
MarianTargetLanguages::$name.iter().cloned().collect(),
)
};
}
pub(super) fn get_default_model(
model_size: &Option<ModelSize>,
source_languages: Option<&Vec<Language>>, source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>, target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> { ) -> Result<TranslationResources<RemoteResource>, RustBertError> {
Ok(match get_marian_model(source_languages, target_languages) {
Ok(marian_resources) => marian_resources,
Err(_) => match model_size {
Some(value) if value == &ModelSize::XLarge => {
get_m2m100_xlarge_resources(source_languages, target_languages)?
}
_ => get_m2m100_large_resources(source_languages, target_languages)?,
},
})
}
pub(super) fn get_marian_model(
source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
let (resources, source_languages, target_languages) = let (resources, source_languages, target_languages) =
if let (Some(source_languages), Some(target_languages)) = if let (Some(source_languages), Some(target_languages)) =
(source_languages, target_languages) (source_languages, target_languages)
@ -446,20 +547,19 @@ impl TranslationModelBuilder {
Ok(TranslationResources { Ok(TranslationResources {
model_type: ModelType::Marian, model_type: ModelType::Marian,
model_resource: Resource::Remote(RemoteResource::from_pretrained(resources.0)), model_resource: RemoteResource::from_pretrained(resources.0),
config_resource: Resource::Remote(RemoteResource::from_pretrained(resources.1)), config_resource: RemoteResource::from_pretrained(resources.1),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(resources.2)), vocab_resource: RemoteResource::from_pretrained(resources.2),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(resources.3)), merges_resource: RemoteResource::from_pretrained(resources.3),
source_languages, source_languages,
target_languages, target_languages,
}) })
} }
fn get_mbart50_resources( pub(super) fn get_mbart50_resources(
&self,
source_languages: Option<&Vec<Language>>, source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>, target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> { ) -> Result<TranslationResources<RemoteResource>, RustBertError> {
if let Some(source_languages) = source_languages { if let Some(source_languages) = source_languages {
if !source_languages if !source_languages
.iter() .iter()
@ -488,28 +588,27 @@ impl TranslationModelBuilder {
Ok(TranslationResources { Ok(TranslationResources {
model_type: ModelType::MBart, model_type: ModelType::MBart,
model_resource: Resource::Remote(RemoteResource::from_pretrained( model_resource: RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY, MBartModelResources::MBART50_MANY_TO_MANY,
)), ),
config_resource: Resource::Remote(RemoteResource::from_pretrained( config_resource: RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY, MBartConfigResources::MBART50_MANY_TO_MANY,
)), ),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained( vocab_resource: RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY, MBartVocabResources::MBART50_MANY_TO_MANY,
)), ),
merges_resource: Resource::Remote(RemoteResource::from_pretrained( merges_resource: RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY, MBartVocabResources::MBART50_MANY_TO_MANY,
)), ),
source_languages: MBartSourceLanguages::MBART50_MANY_TO_MANY.to_vec(), source_languages: MBartSourceLanguages::MBART50_MANY_TO_MANY.to_vec(),
target_languages: MBartTargetLanguages::MBART50_MANY_TO_MANY.to_vec(), target_languages: MBartTargetLanguages::MBART50_MANY_TO_MANY.to_vec(),
}) })
} }
fn get_m2m100_large_resources( pub(super) fn get_m2m100_large_resources(
&self,
source_languages: Option<&Vec<Language>>, source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>, target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> { ) -> Result<TranslationResources<RemoteResource>, RustBertError> {
if let Some(source_languages) = source_languages { if let Some(source_languages) = source_languages {
if !source_languages if !source_languages
.iter() .iter()
@ -538,28 +637,19 @@ impl TranslationModelBuilder {
Ok(TranslationResources { Ok(TranslationResources {
model_type: ModelType::M2M100, model_type: ModelType::M2M100,
model_resource: Resource::Remote(RemoteResource::from_pretrained( model_resource: RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M),
M2M100ModelResources::M2M100_418M, config_resource: RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M),
)), vocab_resource: RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M),
config_resource: Resource::Remote(RemoteResource::from_pretrained( merges_resource: RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M),
M2M100ConfigResources::M2M100_418M,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
)),
source_languages: M2M100SourceLanguages::M2M100_418M.to_vec(), source_languages: M2M100SourceLanguages::M2M100_418M.to_vec(),
target_languages: M2M100TargetLanguages::M2M100_418M.to_vec(), target_languages: M2M100TargetLanguages::M2M100_418M.to_vec(),
}) })
} }
fn get_m2m100_xlarge_resources( pub(super) fn get_m2m100_xlarge_resources(
&self,
source_languages: Option<&Vec<Language>>, source_languages: Option<&Vec<Language>>,
target_languages: Option<&Vec<Language>>, target_languages: Option<&Vec<Language>>,
) -> Result<TranslationResources, RustBertError> { ) -> Result<TranslationResources<RemoteResource>, RustBertError> {
if let Some(source_languages) = source_languages { if let Some(source_languages) = source_languages {
if !source_languages if !source_languages
.iter() .iter()
@ -588,97 +678,12 @@ impl TranslationModelBuilder {
Ok(TranslationResources { Ok(TranslationResources {
model_type: ModelType::M2M100, model_type: ModelType::M2M100,
model_resource: Resource::Remote(RemoteResource::from_pretrained( model_resource: RemoteResource::from_pretrained(M2M100ModelResources::M2M100_1_2B),
M2M100ModelResources::M2M100_1_2B, config_resource: RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_1_2B),
)), vocab_resource: RemoteResource::from_pretrained(M2M100VocabResources::M2M100_1_2B),
config_resource: Resource::Remote(RemoteResource::from_pretrained( merges_resource: RemoteResource::from_pretrained(M2M100MergesResources::M2M100_1_2B),
M2M100ConfigResources::M2M100_1_2B,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_1_2B,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_1_2B,
)),
source_languages: M2M100SourceLanguages::M2M100_1_2B.to_vec(), source_languages: M2M100SourceLanguages::M2M100_1_2B.to_vec(),
target_languages: M2M100TargetLanguages::M2M100_1_2B.to_vec(), target_languages: M2M100TargetLanguages::M2M100_1_2B.to_vec(),
}) })
} }
    /// Creates the translation model based on the specifications provided
    ///
    /// # Returns
    /// * `TranslationModel` Generated translation model
    ///
    /// # Errors
    /// Returns `RustBertError::InvalidConfigurationError` if a model type was requested
    /// without the source/target languages it needs, or if no automated builder is
    /// implemented for the requested model type.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::translation::Language;
    /// use rust_bert::pipelines::translation::TranslationModelBuilder;
    /// fn main() -> anyhow::Result<()> {
    ///     let model = TranslationModelBuilder::new()
    ///         .with_target_languages([
    ///             Language::Japanese,
    ///             Language::Korean,
    ///             Language::ChineseMandarin,
    ///         ])
    ///         .create_model();
    ///     Ok(())
    /// }
    /// ```
    pub fn create_model(&self) -> Result<TranslationModel, RustBertError> {
        // Fall back to CUDA when available if the builder was not given an explicit device.
        let device = self.device.unwrap_or_else(Device::cuda_if_available);
        // NOTE(review): arm ordering matters here. The model-type-specific arms and the
        // `(None, ..)` default arm are tried first; the two trailing error arms only see
        // combinations not already consumed above. Do not reorder.
        let translation_resources = match (
            &self.model_type,
            &self.source_languages,
            &self.target_languages,
        ) {
            (Some(ModelType::M2M100), source_languages, target_languages) => {
                // M2M100 comes in two sizes; only XLarge selects the 1.2B resources,
                // every other (or unset) size uses the large variant.
                match self.model_size {
                    Some(value) if value == ModelSize::XLarge => self.get_m2m100_xlarge_resources(
                        source_languages.as_ref(),
                        target_languages.as_ref(),
                    )?,
                    _ => self.get_m2m100_large_resources(
                        source_languages.as_ref(),
                        target_languages.as_ref(),
                    )?,
                }
            }
            (Some(ModelType::MBart), source_languages, target_languages) => {
                self.get_mbart50_resources(source_languages.as_ref(), target_languages.as_ref())?
            }
            (Some(ModelType::Marian), source_languages, target_languages) => {
                self.get_marian_model(source_languages.as_ref(), target_languages.as_ref())?
            }
            (None, source_languages, target_languages) => {
                // No model type requested: let the helper pick a default model that
                // covers the (possibly unset) language selection.
                self.get_default_model(source_languages.as_ref(), target_languages.as_ref())?
            }
            (_, None, None) | (_, _, None) | (_, None, _) => {
                // Only reachable with `Some(model_type)`: the `(None, ..)` arm above
                // already consumed every `None` model type, so `unwrap` cannot panic.
                return Err(RustBertError::InvalidConfigurationError(format!(
                    "Source and target languages must be specified for {:?}",
                    self.model_type.unwrap()
                )));
            }
            (Some(model_type), _, _) => {
                // Remaining model types have both language sets but no builder support.
                return Err(RustBertError::InvalidConfigurationError(format!(
                    "Automated translation model builder not implemented for {:?}",
                    model_type
                )));
            }
        };
        // Assemble the pipeline configuration from the resolved resources and build the model.
        let translation_config = TranslationConfig::new(
            translation_resources.model_type,
            translation_resources.model_resource,
            translation_resources.config_resource,
            translation_resources.vocab_resource,
            translation_resources.merges_resource,
            translation_resources.source_languages,
            translation_resources.target_languages,
            device,
        );
        TranslationModel::new(translation_config)
    }
} }

View File

@ -14,13 +14,13 @@
use tch::Device; use tch::Device;
use crate::common::error::RustBertError; use crate::common::error::RustBertError;
use crate::common::resources::Resource;
use crate::m2m_100::M2M100Generator; use crate::m2m_100::M2M100Generator;
use crate::marian::MarianGenerator; use crate::marian::MarianGenerator;
use crate::mbart::MBartGenerator; use crate::mbart::MBartGenerator;
use crate::pipelines::common::ModelType; use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator; use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator}; use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
use crate::resources::ResourceProvider;
use crate::t5::T5Generator; use crate::t5::T5Generator;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashSet; use std::collections::HashSet;
@ -374,13 +374,13 @@ pub struct TranslationConfig {
/// Model type used for translation /// Model type used for translation
pub model_type: ModelType, pub model_type: ModelType,
/// Model weights resource /// Model weights resource
pub model_resource: Resource, pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource /// Config resource
pub config_resource: Resource, pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource /// Vocab resource
pub vocab_resource: Resource, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource /// Merges resource
pub merges_resource: Resource, pub merges_resource: Box<dyn ResourceProvider + Send>,
/// Supported source languages /// Supported source languages
pub source_languages: HashSet<Language>, pub source_languages: HashSet<Language>,
/// Supported target languages /// Supported target languages
@ -435,18 +435,18 @@ impl TranslationConfig {
/// }; /// };
/// use rust_bert::pipelines::common::ModelType; /// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::TranslationConfig; /// use rust_bert::pipelines::translation::TranslationConfig;
/// use rust_bert::resources::{RemoteResource, Resource}; /// use rust_bert::resources::RemoteResource;
/// use tch::Device; /// use tch::Device;
/// ///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained( /// let model_resource = RemoteResource::from_pretrained(
/// MarianModelResources::ROMANCE2ENGLISH, /// MarianModelResources::ROMANCE2ENGLISH,
/// )); /// );
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained( /// let config_resource = RemoteResource::from_pretrained(
/// MarianConfigResources::ROMANCE2ENGLISH, /// MarianConfigResources::ROMANCE2ENGLISH,
/// )); /// );
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( /// let vocab_resource = RemoteResource::from_pretrained(
/// MarianVocabResources::ROMANCE2ENGLISH, /// MarianVocabResources::ROMANCE2ENGLISH,
/// )); /// );
/// ///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
@ -464,17 +464,18 @@ impl TranslationConfig {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
pub fn new<S, T>( pub fn new<R, S, T>(
model_type: ModelType, model_type: ModelType,
model_resource: Resource, model_resource: R,
config_resource: Resource, config_resource: R,
vocab_resource: Resource, vocab_resource: R,
merges_resource: Resource, merges_resource: R,
source_languages: S, source_languages: S,
target_languages: T, target_languages: T,
device: impl Into<Option<Device>>, device: impl Into<Option<Device>>,
) -> TranslationConfig ) -> TranslationConfig
where where
R: ResourceProvider + Send + 'static,
S: AsRef<[Language]>, S: AsRef<[Language]>,
T: AsRef<[Language]>, T: AsRef<[Language]>,
{ {
@ -482,10 +483,10 @@ impl TranslationConfig {
TranslationConfig { TranslationConfig {
model_type, model_type,
model_resource, model_resource: Box::new(model_resource),
config_resource, config_resource: Box::new(config_resource),
vocab_resource, vocab_resource: Box::new(vocab_resource),
merges_resource, merges_resource: Box::new(merges_resource),
source_languages: source_languages.as_ref().iter().cloned().collect(), source_languages: source_languages.as_ref().iter().cloned().collect(),
target_languages: target_languages.as_ref().iter().cloned().collect(), target_languages: target_languages.as_ref().iter().cloned().collect(),
device, device,
@ -798,18 +799,18 @@ impl TranslationModel {
/// }; /// };
/// use rust_bert::pipelines::common::ModelType; /// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel}; /// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
/// use rust_bert::resources::{RemoteResource, Resource}; /// use rust_bert::resources::RemoteResource;
/// use tch::Device; /// use tch::Device;
/// ///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained( /// let model_resource = RemoteResource::from_pretrained(
/// MarianModelResources::ROMANCE2ENGLISH, /// MarianModelResources::ROMANCE2ENGLISH,
/// )); /// );
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained( /// let config_resource = RemoteResource::from_pretrained(
/// MarianConfigResources::ROMANCE2ENGLISH, /// MarianConfigResources::ROMANCE2ENGLISH,
/// )); /// );
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( /// let vocab_resource = RemoteResource::from_pretrained(
/// MarianVocabResources::ROMANCE2ENGLISH, /// MarianVocabResources::ROMANCE2ENGLISH,
/// )); /// );
/// ///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
@ -859,21 +860,21 @@ impl TranslationModel {
/// }; /// };
/// use rust_bert::pipelines::common::ModelType; /// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; /// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
/// use rust_bert::resources::{RemoteResource, Resource}; /// use rust_bert::resources::RemoteResource;
/// use tch::Device; /// use tch::Device;
/// ///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained( /// let model_resource = RemoteResource::from_pretrained(
/// MarianModelResources::ENGLISH2ROMANCE, /// MarianModelResources::ENGLISH2ROMANCE,
/// )); /// );
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained( /// let config_resource = RemoteResource::from_pretrained(
/// MarianConfigResources::ENGLISH2ROMANCE, /// MarianConfigResources::ENGLISH2ROMANCE,
/// )); /// );
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( /// let vocab_resource = RemoteResource::from_pretrained(
/// MarianVocabResources::ENGLISH2ROMANCE, /// MarianVocabResources::ENGLISH2ROMANCE,
/// )); /// );
/// let merges_resource = Resource::Remote(RemoteResource::from_pretrained( /// let merges_resource = RemoteResource::from_pretrained(
/// MarianSpmResources::ENGLISH2ROMANCE, /// MarianSpmResources::ENGLISH2ROMANCE,
/// )); /// );
/// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE; /// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
/// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE; /// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
/// ///
@ -938,15 +939,10 @@ mod test {
#[test] #[test]
#[ignore] // no need to run, compilation is enough to verify it is Send #[ignore] // no need to run, compilation is enough to verify it is Send
fn test() { fn test() {
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
MarianModelResources::ROMANCE2ENGLISH, let config_resource =
)); RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
MarianConfigResources::ROMANCE2ENGLISH,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianVocabResources::ROMANCE2ENGLISH,
));
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;

View File

@ -99,10 +99,7 @@
//! ``` //! ```
use crate::albert::AlbertForSequenceClassification; use crate::albert::AlbertForSequenceClassification;
use crate::bart::{ use crate::bart::BartForSequenceClassification;
BartConfigResources, BartForSequenceClassification, BartMergesResources, BartModelResources,
BartVocabResources,
};
use crate::bert::BertForSequenceClassification; use crate::bert::BertForSequenceClassification;
use crate::deberta::DebertaForSequenceClassification; use crate::deberta::DebertaForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier; use crate::distilbert::DistilBertModelClassifier;
@ -110,7 +107,7 @@ use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification; use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption}; use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::sequence_classification::Label; use crate::pipelines::sequence_classification::Label;
use crate::resources::{RemoteResource, Resource}; use crate::resources::ResourceProvider;
use crate::roberta::RobertaForSequenceClassification; use crate::roberta::RobertaForSequenceClassification;
use crate::xlnet::XLNetForSequenceClassification; use crate::xlnet::XLNetForSequenceClassification;
use crate::RustBertError; use crate::RustBertError;
@ -122,19 +119,25 @@ use tch::kind::Kind::{Bool, Float};
use tch::nn::VarStore; use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Tensor}; use tch::{nn, no_grad, Device, Tensor};
#[cfg(feature = "remote")]
use crate::{
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
resources::RemoteResource,
};
/// # Configuration for ZeroShotClassificationModel /// # Configuration for ZeroShotClassificationModel
/// Contains information regarding the model to load and device to place the model on. /// Contains information regarding the model to load and device to place the model on.
pub struct ZeroShotClassificationConfig { pub struct ZeroShotClassificationConfig {
/// Model type /// Model type
pub model_type: ModelType, pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model on CoNLL) /// Model weights resource (default: pretrained BERT model on CoNLL)
pub model_resource: Resource, pub model_resource: Box<dyn ResourceProvider + Send>,
/// Config resource (default: pretrained BERT model on CoNLL) /// Config resource (default: pretrained BERT model on CoNLL)
pub config_resource: Resource, pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model on CoNLL) /// Vocab resource (default: pretrained BERT model on CoNLL)
pub vocab_resource: Resource, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: None) /// Merges resource (default: None)
pub merges_resource: Option<Resource>, pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Automatically lower case all input upon tokenization (assumes a lower-cased model) /// Automatically lower case all input upon tokenization (assumes a lower-cased model)
pub lower_case: bool, pub lower_case: bool,
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models /// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
@ -151,27 +154,30 @@ impl ZeroShotClassificationConfig {
/// # Arguments /// # Arguments
/// ///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!) /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model - The `Resource` pointing to the model to load (e.g. model.ot) /// * model - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * config - The `Resource' pointing to the model configuration to load (e.g. config.json) /// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. /// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model) /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new( pub fn new<R>(
model_type: ModelType, model_type: ModelType,
model_resource: Resource, model_resource: R,
config_resource: Resource, config_resource: R,
vocab_resource: Resource, vocab_resource: R,
merges_resource: Option<Resource>, merges_resource: Option<R>,
lower_case: bool, lower_case: bool,
strip_accents: impl Into<Option<bool>>, strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>, add_prefix_space: impl Into<Option<bool>>,
) -> ZeroShotClassificationConfig { ) -> ZeroShotClassificationConfig
where
R: ResourceProvider + Send + 'static,
{
ZeroShotClassificationConfig { ZeroShotClassificationConfig {
model_type, model_type,
model_resource, model_resource: Box::new(model_resource),
config_resource, config_resource: Box::new(config_resource),
vocab_resource, vocab_resource: Box::new(vocab_resource),
merges_resource, merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
lower_case, lower_case,
strip_accents: strip_accents.into(), strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(), add_prefix_space: add_prefix_space.into(),
@ -180,21 +186,22 @@ impl ZeroShotClassificationConfig {
} }
} }
#[cfg(feature = "remote")]
impl Default for ZeroShotClassificationConfig { impl Default for ZeroShotClassificationConfig {
/// Provides a defaultSST-2 sentiment analysis model (English) /// Provides a defaultSST-2 sentiment analysis model (English)
fn default() -> ZeroShotClassificationConfig { fn default() -> ZeroShotClassificationConfig {
ZeroShotClassificationConfig { ZeroShotClassificationConfig {
model_type: ModelType::Bart, model_type: ModelType::Bart,
model_resource: Resource::Remote(RemoteResource::from_pretrained( model_resource: Box::new(RemoteResource::from_pretrained(
BartModelResources::BART_MNLI, BartModelResources::BART_MNLI,
)), )),
config_resource: Resource::Remote(RemoteResource::from_pretrained( config_resource: Box::new(RemoteResource::from_pretrained(
BartConfigResources::BART_MNLI, BartConfigResources::BART_MNLI,
)), )),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained( vocab_resource: Box::new(RemoteResource::from_pretrained(
BartVocabResources::BART_MNLI, BartVocabResources::BART_MNLI,
)), )),
merges_resource: Some(Resource::Remote(RemoteResource::from_pretrained( merges_resource: Some(Box::new(RemoteResource::from_pretrained(
BartMergesResources::BART_MNLI, BartMergesResources::BART_MNLI,
))), ))),
lower_case: false, lower_case: false,

View File

@ -20,17 +20,17 @@
//! use rust_bert::prophetnet::{ //! use rust_bert::prophetnet::{
//! ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources, //! ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
//! }; //! };
//! use rust_bert::resources::{RemoteResource, Resource}; //! use rust_bert::resources::RemoteResource;
//! use tch::Device; //! use tch::Device;
//! //!
//! fn main() -> anyhow::Result<()> { //! fn main() -> anyhow::Result<()> {
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained( //! let config_resource = Box::new(RemoteResource::from_pretrained(
//! ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM, //! ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
//! )); //! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( //! let vocab_resource = Box::new(RemoteResource::from_pretrained(
//! ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM, //! ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
//! )); //! ));
//! let weights_resource = Resource::Remote(RemoteResource::from_pretrained( //! let weights_resource = Box::new(RemoteResource::from_pretrained(
//! ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM, //! ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
//! )); //! ));
//! //!

View File

@ -18,8 +18,6 @@ use rust_tokenizers::vocab::{ProphetNetVocab, Vocab};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tch::{nn, Kind, Tensor}; use tch::{nn, Kind, Tensor};
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{ use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator, PreparedInput, PrivateLanguageGenerator,
@ -909,40 +907,9 @@ impl ProphetNetConditionalGenerator {
pub fn new( pub fn new(
generate_config: GenerateConfig, generate_config: GenerateConfig,
) -> Result<ProphetNetConditionalGenerator, RustBertError> { ) -> Result<ProphetNetConditionalGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models let config_path = generate_config.config_resource.get_local_path()?;
let model_resource = if generate_config.model_resource let vocab_path = generate_config.vocab_resource.get_local_path()?;
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) let weights_path = generate_config.model_resource.get_local_path()?;
{
Resource::Remote(RemoteResource::from_pretrained(
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device; let device = generate_config.device;
generate_config.validate(); generate_config.validate();

View File

@ -19,19 +19,19 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::reformer::{ReformerConfig, ReformerModel}; //! use rust_bert::reformer::{ReformerConfig, ReformerModel};
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::ReformerTokenizer; //! use rust_tokenizers::tokenizer::ReformerTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/weights.ot"), //! local_path: PathBuf::from("path/to/weights.ot"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"), //! local_path: PathBuf::from("path/to/spiece.model"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?; //! let weights_path = weights_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;

View File

@ -23,8 +23,6 @@ use tch::{nn, Device, Kind, Tensor};
use crate::common::activations::Activation; use crate::common::activations::Activation;
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair; use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{ use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator, PreparedInput, PrivateLanguageGenerator,
@ -1019,40 +1017,9 @@ pub struct ReformerGenerator {
impl ReformerGenerator { impl ReformerGenerator {
pub fn new(generate_config: GenerateConfig) -> Result<ReformerGenerator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<ReformerGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models let config_path = generate_config.config_resource.get_local_path()?;
let model_resource = if generate_config.model_resource let vocab_path = generate_config.vocab_resource.get_local_path()?;
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) let weights_path = generate_config.model_resource.get_local_path()?;
{
Resource::Remote(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device; let device = generate_config.device;
generate_config.validate(); generate_config.validate();

View File

@ -23,23 +23,23 @@
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::bert::BertConfig; //! use rust_bert::bert::BertConfig;
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::roberta::RobertaForMaskedLM; //! use rust_bert::roberta::RobertaForMaskedLM;
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::RobertaTokenizer; //! use rust_tokenizers::tokenizer::RobertaTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let vocab_resource = Resource::Local(LocalResource { //! let vocab_resource = LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"), //! local_path: PathBuf::from("path/to/vocab.txt"),
//! }); //! };
//! let merges_resource = Resource::Local(LocalResource { //! let merges_resource = LocalResource {
//! local_path: PathBuf::from("path/to/merges.txt"), //! local_path: PathBuf::from("path/to/merges.txt"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?; //! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?; //! let merges_path = merges_resource.get_local_path()?;

View File

@ -19,20 +19,20 @@
//! # //! #
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//! # use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::resources::{LocalResource, Resource}; //! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::t5::{T5Config, T5ForConditionalGeneration}; //! use rust_bert::t5::{T5Config, T5ForConditionalGeneration};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::T5Tokenizer; //! use rust_tokenizers::tokenizer::T5Tokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { //! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"), //! local_path: PathBuf::from("path/to/config.json"),
//! }); //! };
//! let sentence_piece_resource = Resource::Local(LocalResource { //! let sentence_piece_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"), //! local_path: PathBuf::from("path/to/spiece.model"),
//! }); //! };
//! let weights_resource = Resource::Local(LocalResource { //! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"), //! local_path: PathBuf::from("path/to/model.ot"),
//! }); //! };
//! let config_path = config_resource.get_local_path()?; //! let config_path = config_resource.get_local_path()?;
//! let spiece_path = sentence_piece_resource.get_local_path()?; //! let spiece_path = sentence_piece_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?; //! let weights_path = weights_resource.get_local_path()?;

View File

@ -18,8 +18,6 @@ use serde::{Deserialize, Serialize};
use tch::nn::embedding; use tch::nn::embedding;
use tch::{nn, Tensor}; use tch::{nn, Tensor};
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{ use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator, PreparedInput, PrivateLanguageGenerator,
@ -715,34 +713,9 @@ pub struct T5Generator {
impl T5Generator { impl T5Generator {
pub fn new(generate_config: GenerateConfig) -> Result<T5Generator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<T5Generator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models let config_path = generate_config.config_resource.get_local_path()?;
let model_resource = if generate_config.model_resource let vocab_path = generate_config.vocab_resource.get_local_path()?;
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)) let weights_path = generate_config.model_resource.get_local_path()?;
{
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device; let device = generate_config.device;
generate_config.validate(); generate_config.validate();

View File

@ -22,18 +22,18 @@
//! use rust_bert::pipelines::common::ModelType; //! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::generation_utils::LanguageGenerator; //! use rust_bert::pipelines::generation_utils::LanguageGenerator;
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; //! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
//! use rust_bert::resources::{RemoteResource, Resource}; //! use rust_bert::resources::RemoteResource;
//! use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources}; //! use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained( //! let config_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetConfigResources::XLNET_BASE_CASED, //! XLNetConfigResources::XLNET_BASE_CASED,
//! )); //! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( //! let vocab_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetVocabResources::XLNET_BASE_CASED, //! XLNetVocabResources::XLNET_BASE_CASED,
//! )); //! ));
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained( //! let merges_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetVocabResources::XLNET_BASE_CASED, //! XLNetVocabResources::XLNET_BASE_CASED,
//! )); //! ));
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained( //! let model_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetModelResources::XLNET_BASE_CASED, //! XLNetModelResources::XLNET_BASE_CASED,
//! )); //! ));
//! let generate_config = TextGenerationConfig { //! let generate_config = TextGenerationConfig {

View File

@ -6,7 +6,7 @@ use rust_bert::albert::{
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification, AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertModelResources, AlbertVocabResources, AlbertModelResources, AlbertVocabResources,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{AlbertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{AlbertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab; use rust_tokenizers::vocab::Vocab;
@ -16,13 +16,13 @@ use tch::{nn, no_grad, Device, Tensor};
#[test] #[test]
fn albert_masked_lm() -> anyhow::Result<()> { fn albert_masked_lm() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2, AlbertConfigResources::ALBERT_BASE_V2,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2, AlbertVocabResources::ALBERT_BASE_V2,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2, AlbertModelResources::ALBERT_BASE_V2,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -87,10 +87,10 @@ fn albert_masked_lm() -> anyhow::Result<()> {
#[test] #[test]
fn albert_for_sequence_classification() -> anyhow::Result<()> { fn albert_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2, AlbertConfigResources::ALBERT_BASE_V2,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2, AlbertVocabResources::ALBERT_BASE_V2,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -153,10 +153,10 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
#[test] #[test]
fn albert_for_multiple_choice() -> anyhow::Result<()> { fn albert_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2, AlbertConfigResources::ALBERT_BASE_V2,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2, AlbertVocabResources::ALBERT_BASE_V2,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -219,10 +219,10 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
#[test] #[test]
fn albert_for_token_classification() -> anyhow::Result<()> { fn albert_for_token_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2, AlbertConfigResources::ALBERT_BASE_V2,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2, AlbertVocabResources::ALBERT_BASE_V2,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -286,10 +286,10 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
#[test] #[test]
fn albert_for_question_answering() -> anyhow::Result<()> { fn albert_for_question_answering() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2, AlbertConfigResources::ALBERT_BASE_V2,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2, AlbertVocabResources::ALBERT_BASE_V2,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;

View File

@ -6,7 +6,7 @@ use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationMode
use rust_bert::pipelines::zero_shot_classification::{ use rust_bert::pipelines::zero_shot_classification::{
ZeroShotClassificationConfig, ZeroShotClassificationModel, ZeroShotClassificationConfig, ZeroShotClassificationModel,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor}; use tch::{nn, Device, Tensor};
@ -14,16 +14,16 @@ use tch::{nn, Device, Tensor};
#[test] #[test]
fn bart_lm_model() -> anyhow::Result<()> { fn bart_lm_model() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6, BartConfigResources::DISTILBART_CNN_6_6,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6, BartVocabResources::DISTILBART_CNN_6_6,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6, BartMergesResources::DISTILBART_CNN_6_6,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
BartModelResources::DISTILBART_CNN_6_6, BartModelResources::DISTILBART_CNN_6_6,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -77,16 +77,16 @@ fn bart_lm_model() -> anyhow::Result<()> {
#[test] #[test]
fn bart_summarization_greedy() -> anyhow::Result<()> { fn bart_summarization_greedy() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6, BartConfigResources::DISTILBART_CNN_6_6,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6, BartVocabResources::DISTILBART_CNN_6_6,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6, BartMergesResources::DISTILBART_CNN_6_6,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
BartModelResources::DISTILBART_CNN_6_6, BartModelResources::DISTILBART_CNN_6_6,
)); ));
let summarization_config = SummarizationConfig { let summarization_config = SummarizationConfig {
@ -138,16 +138,16 @@ about exoplanets like K2-18b."];
#[test] #[test]
fn bart_summarization_beam_search() -> anyhow::Result<()> { fn bart_summarization_beam_search() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6, BartConfigResources::DISTILBART_CNN_6_6,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6, BartVocabResources::DISTILBART_CNN_6_6,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6, BartMergesResources::DISTILBART_CNN_6_6,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
BartModelResources::DISTILBART_CNN_6_6, BartModelResources::DISTILBART_CNN_6_6,
)); ));
let summarization_config = SummarizationConfig { let summarization_config = SummarizationConfig {

View File

@ -11,7 +11,7 @@ use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::question_answering::{ use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab; use rust_tokenizers::vocab::Vocab;
@ -21,12 +21,9 @@ use tch::{nn, no_grad, Device, Tensor};
#[test] #[test]
fn bert_masked_lm() -> anyhow::Result<()> { fn bert_masked_lm() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT)); let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let vocab_resource = let weights_resource = RemoteResource::from_pretrained(BertModelResources::BERT);
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?; let weights_path = weights_resource.get_local_path()?;
@ -106,10 +103,8 @@ fn bert_masked_lm() -> anyhow::Result<()> {
#[test] #[test]
fn bert_for_sequence_classification() -> anyhow::Result<()> { fn bert_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT)); let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
@ -170,10 +165,8 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
#[test] #[test]
fn bert_for_multiple_choice() -> anyhow::Result<()> { fn bert_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT)); let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
@ -230,10 +223,8 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
#[test] #[test]
fn bert_for_token_classification() -> anyhow::Result<()> { fn bert_for_token_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT)); let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
@ -295,10 +286,8 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
#[test] #[test]
fn bert_for_question_answering() -> anyhow::Result<()> { fn bert_for_question_answering() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT)); let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
@ -422,11 +411,9 @@ fn bert_question_answering() -> anyhow::Result<()> {
// Set-up question answering model // Set-up question answering model
let config = QuestionAnsweringConfig::new( let config = QuestionAnsweringConfig::new(
ModelType::Bert, ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)), RemoteResource::from_pretrained(BertModelResources::BERT_QA),
Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
BertConfigResources::BERT_QA, RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
None, //merges resource only relevant with ModelType::Roberta None, //merges resource only relevant with ModelType::Roberta
false, false,
false, false,

View File

@ -3,7 +3,7 @@ use rust_bert::deberta::{
DebertaForSequenceClassification, DebertaForTokenClassification, DebertaMergesResources, DebertaForSequenceClassification, DebertaForTokenClassification, DebertaMergesResources,
DebertaModelResources, DebertaVocabResources, DebertaModelResources, DebertaVocabResources,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use std::collections::HashMap; use std::collections::HashMap;
@ -14,16 +14,16 @@ extern crate anyhow;
#[test] #[test]
fn deberta_natural_language_inference() -> anyhow::Result<()> { fn deberta_natural_language_inference() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
DebertaConfigResources::DEBERTA_BASE_MNLI, DebertaConfigResources::DEBERTA_BASE_MNLI,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
DebertaVocabResources::DEBERTA_BASE_MNLI, DebertaVocabResources::DEBERTA_BASE_MNLI,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
DebertaMergesResources::DEBERTA_BASE_MNLI, DebertaMergesResources::DEBERTA_BASE_MNLI,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
DebertaModelResources::DEBERTA_BASE_MNLI, DebertaModelResources::DEBERTA_BASE_MNLI,
)); ));
@ -87,7 +87,7 @@ fn deberta_natural_language_inference() -> anyhow::Result<()> {
#[test] #[test]
fn deberta_masked_lm() -> anyhow::Result<()> { fn deberta_masked_lm() -> anyhow::Result<()> {
// Set-up masked LM model // Set-up masked LM model
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
DebertaConfigResources::DEBERTA_BASE_MNLI, DebertaConfigResources::DEBERTA_BASE_MNLI,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -142,13 +142,13 @@ fn deberta_masked_lm() -> anyhow::Result<()> {
#[test] #[test]
fn deberta_for_token_classification() -> anyhow::Result<()> { fn deberta_for_token_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
DebertaConfigResources::DEBERTA_BASE_MNLI, DebertaConfigResources::DEBERTA_BASE_MNLI,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
DebertaVocabResources::DEBERTA_BASE_MNLI, DebertaVocabResources::DEBERTA_BASE_MNLI,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
DebertaMergesResources::DEBERTA_BASE_MNLI, DebertaMergesResources::DEBERTA_BASE_MNLI,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -203,13 +203,13 @@ fn deberta_for_token_classification() -> anyhow::Result<()> {
#[test] #[test]
fn deberta_for_question_answering() -> anyhow::Result<()> { fn deberta_for_question_answering() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
DebertaConfigResources::DEBERTA_BASE_MNLI, DebertaConfigResources::DEBERTA_BASE_MNLI,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
DebertaVocabResources::DEBERTA_BASE_MNLI, DebertaVocabResources::DEBERTA_BASE_MNLI,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
DebertaMergesResources::DEBERTA_BASE_MNLI, DebertaMergesResources::DEBERTA_BASE_MNLI,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;

View File

@ -5,7 +5,7 @@ use rust_bert::distilbert::{
}; };
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel}; use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity}; use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab; use rust_tokenizers::vocab::Vocab;
@ -42,13 +42,13 @@ fn distilbert_sentiment_classifier() -> anyhow::Result<()> {
#[test] #[test]
fn distilbert_masked_lm() -> anyhow::Result<()> { fn distilbert_masked_lm() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT, DistilBertConfigResources::DISTIL_BERT,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT, DistilBertVocabResources::DISTIL_BERT,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT, DistilBertModelResources::DISTIL_BERT,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -123,10 +123,10 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
#[test] #[test]
fn distilbert_for_question_answering() -> anyhow::Result<()> { fn distilbert_for_question_answering() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD, DistilBertConfigResources::DISTIL_BERT_SQUAD,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD, DistilBertVocabResources::DISTIL_BERT_SQUAD,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -188,10 +188,10 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
#[test] #[test]
fn distilbert_for_token_classification() -> anyhow::Result<()> { fn distilbert_for_token_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT, DistilBertConfigResources::DISTIL_BERT,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT, DistilBertVocabResources::DISTIL_BERT,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;

View File

@ -3,7 +3,7 @@ use rust_bert::gpt2::{
Gpt2VocabResources, Gpt2VocabResources,
}; };
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor}; use tch::{nn, Device, Tensor};
@ -11,16 +11,16 @@ use tch::{nn, Device, Tensor};
#[test] #[test]
fn distilgpt2_lm_model() -> anyhow::Result<()> { fn distilgpt2_lm_model() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
Gpt2ConfigResources::DISTIL_GPT2, Gpt2ConfigResources::DISTIL_GPT2,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
Gpt2VocabResources::DISTIL_GPT2, Gpt2VocabResources::DISTIL_GPT2,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::DISTIL_GPT2, Gpt2MergesResources::DISTIL_GPT2,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::DISTIL_GPT2, Gpt2ModelResources::DISTIL_GPT2,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;

View File

@ -2,7 +2,7 @@ use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM, ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM,
ElectraModelResources, ElectraVocabResources, ElectraModelResources, ElectraVocabResources,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab; use rust_tokenizers::vocab::Vocab;
@ -11,13 +11,13 @@ use tch::{nn, no_grad, Device, Tensor};
#[test] #[test]
fn electra_masked_lm() -> anyhow::Result<()> { fn electra_masked_lm() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR, ElectraConfigResources::BASE_GENERATOR,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR, ElectraVocabResources::BASE_GENERATOR,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR, ElectraModelResources::BASE_GENERATOR,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -95,13 +95,13 @@ fn electra_masked_lm() -> anyhow::Result<()> {
#[test] #[test]
fn electra_discriminator() -> anyhow::Result<()> { fn electra_discriminator() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR, ElectraConfigResources::BASE_DISCRIMINATOR,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR, ElectraVocabResources::BASE_DISCRIMINATOR,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR, ElectraModelResources::BASE_DISCRIMINATOR,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;

View File

@ -7,7 +7,7 @@ use rust_bert::fnet::{
}; };
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel, SentimentPolarity}; use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel, SentimentPolarity};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{FNetTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{FNetTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab; use rust_tokenizers::vocab::Vocab;
@ -17,12 +17,9 @@ use tch::{nn, no_grad, Device, Tensor};
#[test] #[test]
fn fnet_masked_lm() -> anyhow::Result<()> { fn fnet_masked_lm() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE)); let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let vocab_resource = let weights_resource = Box::new(RemoteResource::from_pretrained(FNetModelResources::BASE));
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetModelResources::BASE));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?; let weights_path = weights_resource.get_local_path()?;
@ -85,13 +82,13 @@ fn fnet_masked_lm() -> anyhow::Result<()> {
#[test] #[test]
fn fnet_for_sequence_classification() -> anyhow::Result<()> { fn fnet_for_sequence_classification() -> anyhow::Result<()> {
// Set up classifier // Set up classifier
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
FNetConfigResources::BASE_SST2, FNetConfigResources::BASE_SST2,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
FNetVocabResources::BASE_SST2, FNetVocabResources::BASE_SST2,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
FNetModelResources::BASE_SST2, FNetModelResources::BASE_SST2,
)); ));
@ -128,10 +125,8 @@ fn fnet_for_sequence_classification() -> anyhow::Result<()> {
#[test] #[test]
fn fnet_for_multiple_choice() -> anyhow::Result<()> { fn fnet_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE)); let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
@ -188,10 +183,8 @@ fn fnet_for_multiple_choice() -> anyhow::Result<()> {
#[test] #[test]
fn fnet_for_token_classification() -> anyhow::Result<()> { fn fnet_for_token_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE)); let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
@ -251,10 +244,8 @@ fn fnet_for_token_classification() -> anyhow::Result<()> {
#[test] #[test]
fn fnet_for_question_answering() -> anyhow::Result<()> { fn fnet_for_question_answering() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE)); let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;

View File

@ -10,7 +10,7 @@ use rust_bert::pipelines::generation_utils::{
Cache, GenerateConfig, GenerateOptions, LMHeadModel, LanguageGenerator, Cache, GenerateConfig, GenerateOptions, LMHeadModel, LanguageGenerator,
}; };
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor}; use tch::{nn, Device, Tensor};
@ -18,14 +18,10 @@ use tch::{nn, Device, Tensor};
#[test] #[test]
fn gpt2_lm_model() -> anyhow::Result<()> { fn gpt2_lm_model() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = let config_resource = RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2);
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let vocab_resource = RemoteResource::from_pretrained(Gpt2VocabResources::GPT2);
let vocab_resource = let merges_resource = RemoteResource::from_pretrained(Gpt2MergesResources::GPT2);
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); let weights_resource = RemoteResource::from_pretrained(Gpt2ModelResources::GPT2);
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?;
@ -114,14 +110,10 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
#[test] #[test]
fn gpt2_generation_greedy() -> anyhow::Result<()> { fn gpt2_generation_greedy() -> anyhow::Result<()> {
// Resources definition // Resources definition
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let vocab_resource = let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = TextGenerationConfig { let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2, model_type: ModelType::GPT2,
@ -150,14 +142,10 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
#[test] #[test]
fn gpt2_generation_beam_search() -> anyhow::Result<()> { fn gpt2_generation_beam_search() -> anyhow::Result<()> {
// Resources definition // Resources definition
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let vocab_resource = let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = TextGenerationConfig { let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2, model_type: ModelType::GPT2,
@ -198,14 +186,10 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
#[test] #[test]
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Result<()> { fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Result<()> {
// Resources definition // Resources definition
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let vocab_resource = let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = TextGenerationConfig { let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2, model_type: ModelType::GPT2,
@ -259,14 +243,10 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
#[test] #[test]
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> { fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
// Resources definition // Resources definition
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let vocab_resource = let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = TextGenerationConfig { let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2, model_type: ModelType::GPT2,
@ -319,14 +299,10 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
#[test] #[test]
fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> { fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
// Resources definition // Resources definition
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let vocab_resource = let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = TextGenerationConfig { let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2, model_type: ModelType::GPT2,
@ -381,14 +357,10 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()
#[test] #[test]
fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> { fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
// Resources definition // Resources definition
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let vocab_resource = let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> { fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
let paragraph_tokens = [198, 628]; let paragraph_tokens = [198, 628];
@ -450,14 +422,10 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
#[test] #[test]
fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> { fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
// Resources definition // Resources definition
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let vocab_resource = let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig { let generate_config = GenerateConfig {
max_length: 36, max_length: 36,
@ -520,14 +488,10 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
#[test] #[test]
fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> { fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
// Resources definition // Resources definition
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let vocab_resource = let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig { let generate_config = GenerateConfig {
max_length: 36, max_length: 36,
@ -590,14 +554,10 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
#[test] #[test]
fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> { fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
// Resources definition // Resources definition
let config_resource = let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let vocab_resource = let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> { fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
let paragraph_tokens = [198, 628]; let paragraph_tokens = [198, 628];

View File

@ -4,7 +4,7 @@ use rust_bert::gpt_neo::{
}; };
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor}; use tch::{nn, Device, Tensor};
@ -12,16 +12,16 @@ use tch::{nn, Device, Tensor};
#[test] #[test]
fn gpt_neo_lm() -> anyhow::Result<()> { fn gpt_neo_lm() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
GptNeoConfigResources::GPT_NEO_125M, GptNeoConfigResources::GPT_NEO_125M,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
GptNeoVocabResources::GPT_NEO_125M, GptNeoVocabResources::GPT_NEO_125M,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
GptNeoMergesResources::GPT_NEO_125M, GptNeoMergesResources::GPT_NEO_125M,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
GptNeoModelResources::GPT_NEO_125M, GptNeoModelResources::GPT_NEO_125M,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -109,16 +109,16 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
#[test] #[test]
fn test_generation_gpt_neo() -> anyhow::Result<()> { fn test_generation_gpt_neo() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
GptNeoConfigResources::GPT_NEO_125M, GptNeoConfigResources::GPT_NEO_125M,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
GptNeoVocabResources::GPT_NEO_125M, GptNeoVocabResources::GPT_NEO_125M,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
GptNeoMergesResources::GPT_NEO_125M, GptNeoMergesResources::GPT_NEO_125M,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
GptNeoModelResources::GPT_NEO_125M, GptNeoModelResources::GPT_NEO_125M,
)); ));

View File

@ -11,7 +11,7 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{ use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{RobertaVocab, Vocab}; use rust_tokenizers::vocab::{RobertaVocab, Vocab};
@ -21,18 +21,14 @@ use tch::{nn, no_grad, Device, Tensor};
#[test] #[test]
fn longformer_masked_lm() -> anyhow::Result<()> { fn longformer_masked_lm() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource =
LongformerConfigResources::LONGFORMER_BASE_4096, RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
)); let vocab_resource =
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
LongformerVocabResources::LONGFORMER_BASE_4096, let merges_resource =
)); RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource =
LongformerMergesResources::LONGFORMER_BASE_4096, RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_4096);
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerModelResources::LONGFORMER_BASE_4096,
));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?;
@ -176,15 +172,12 @@ fn longformer_masked_lm() -> anyhow::Result<()> {
#[test] #[test]
fn longformer_for_sequence_classification() -> anyhow::Result<()> { fn longformer_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource =
LongformerConfigResources::LONGFORMER_BASE_4096, RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
)); let vocab_resource =
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
LongformerVocabResources::LONGFORMER_BASE_4096, let merges_resource =
)); RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?;
@ -245,15 +238,12 @@ fn longformer_for_sequence_classification() -> anyhow::Result<()> {
#[test] #[test]
fn longformer_for_multiple_choice() -> anyhow::Result<()> { fn longformer_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource =
LongformerConfigResources::LONGFORMER_BASE_4096, RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
)); let vocab_resource =
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
LongformerVocabResources::LONGFORMER_BASE_4096, let merges_resource =
)); RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?;
@ -321,15 +311,12 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
#[test] #[test]
fn mobilebert_for_token_classification() -> anyhow::Result<()> { fn mobilebert_for_token_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource =
LongformerConfigResources::LONGFORMER_BASE_4096, RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
)); let vocab_resource =
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
LongformerVocabResources::LONGFORMER_BASE_4096, let merges_resource =
)); RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?;
@ -394,18 +381,12 @@ fn longformer_for_question_answering() -> anyhow::Result<()> {
// Set-up Question Answering model // Set-up Question Answering model
let config = QuestionAnsweringConfig::new( let config = QuestionAnsweringConfig::new(
ModelType::Longformer, ModelType::Longformer,
Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
LongformerModelResources::LONGFORMER_BASE_SQUAD1, RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
)), RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
Resource::Remote(RemoteResource::from_pretrained( Some(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
)),
Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
)),
Some(Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_SQUAD1, LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
))), )),
false, false,
None, None,
false, false,

View File

@ -4,7 +4,7 @@ use rust_bert::m2m_100::{
}; };
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{M2M100Tokenizer, Tokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{M2M100Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor}; use tch::{nn, Device, Tensor};
@ -12,18 +12,10 @@ use tch::{nn, Device, Tensor};
#[test] #[test]
fn m2m100_lm_model() -> anyhow::Result<()> { fn m2m100_lm_model() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
M2M100ConfigResources::M2M100_418M, let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
)); let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
M2M100VocabResources::M2M100_418M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?;
@ -76,18 +68,10 @@ fn m2m100_lm_model() -> anyhow::Result<()> {
#[test] #[test]
fn m2m100_translation() -> anyhow::Result<()> { fn m2m100_translation() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
M2M100ModelResources::M2M100_418M, let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
)); let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
M2M100ConfigResources::M2M100_418M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
));
let source_languages = M2M100SourceLanguages::M2M100_418M; let source_languages = M2M100SourceLanguages::M2M100_418M;
let target_languages = M2M100TargetLanguages::M2M100_418M; let target_languages = M2M100TargetLanguages::M2M100_418M;

View File

@ -6,25 +6,17 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{ use rust_bert::pipelines::translation::{
Language, TranslationConfig, TranslationModel, TranslationModelBuilder, Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use tch::Device; use tch::Device;
#[test] #[test]
// #[cfg_attr(not(feature = "all-tests"), ignore)] // #[cfg_attr(not(feature = "all-tests"), ignore)]
fn test_translation() -> anyhow::Result<()> { fn test_translation() -> anyhow::Result<()> {
// Set-up translation model // Set-up translation model
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = RemoteResource::from_pretrained(MarianModelResources::ENGLISH2ROMANCE);
MarianModelResources::ENGLISH2ROMANCE, let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ENGLISH2ROMANCE);
)); let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ENGLISH2ROMANCE);
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ENGLISH2ROMANCE);
MarianConfigResources::ENGLISH2ROMANCE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianVocabResources::ENGLISH2ROMANCE,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianSpmResources::ENGLISH2ROMANCE,
));
let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE; let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE; let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;

View File

@ -3,7 +3,7 @@ use rust_bert::mbart::{
}; };
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder}; use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor}; use tch::{nn, Device, Tensor};
@ -11,13 +11,13 @@ use tch::{nn, Device, Tensor};
#[test] #[test]
fn mbart_lm_model() -> anyhow::Result<()> { fn mbart_lm_model() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY, MBartConfigResources::MBART50_MANY_TO_MANY,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY, MBartVocabResources::MBART50_MANY_TO_MANY,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY, MBartModelResources::MBART50_MANY_TO_MANY,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;

View File

@ -5,7 +5,7 @@ use rust_bert::mobilebert::{
MobileBertModelResources, MobileBertVocabResources, MobileBertModelResources, MobileBertVocabResources,
}; };
use rust_bert::pipelines::pos_tagging::POSModel; use rust_bert::pipelines::pos_tagging::POSModel;
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab; use rust_tokenizers::vocab::Vocab;
@ -15,13 +15,13 @@ use tch::{nn, no_grad, Device, Tensor};
#[test] #[test]
fn mobilebert_masked_model() -> anyhow::Result<()> { fn mobilebert_masked_model() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED, MobileBertConfigResources::MOBILEBERT_UNCASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED, MobileBertVocabResources::MOBILEBERT_UNCASED,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
MobileBertModelResources::MOBILEBERT_UNCASED, MobileBertModelResources::MOBILEBERT_UNCASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -111,10 +111,10 @@ fn mobilebert_masked_model() -> anyhow::Result<()> {
#[test] #[test]
fn mobilebert_for_sequence_classification() -> anyhow::Result<()> { fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED, MobileBertConfigResources::MOBILEBERT_UNCASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED, MobileBertVocabResources::MOBILEBERT_UNCASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -162,10 +162,10 @@ fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
#[test] #[test]
fn mobilebert_for_multiple_choice() -> anyhow::Result<()> { fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED, MobileBertConfigResources::MOBILEBERT_UNCASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED, MobileBertVocabResources::MOBILEBERT_UNCASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -220,10 +220,10 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
#[test] #[test]
fn mobilebert_for_token_classification() -> anyhow::Result<()> { fn mobilebert_for_token_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED, MobileBertConfigResources::MOBILEBERT_UNCASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED, MobileBertVocabResources::MOBILEBERT_UNCASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -273,10 +273,10 @@ fn mobilebert_for_token_classification() -> anyhow::Result<()> {
#[test] #[test]
fn mobilebert_for_question_answering() -> anyhow::Result<()> { fn mobilebert_for_question_answering() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED, MobileBertConfigResources::MOBILEBERT_UNCASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED, MobileBertVocabResources::MOBILEBERT_UNCASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;

View File

@ -6,7 +6,7 @@ use rust_bert::openai_gpt::{
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor}; use tch::{nn, Device, Tensor};
@ -14,16 +14,16 @@ use tch::{nn, Device, Tensor};
#[test] #[test]
fn openai_gpt_lm_model() -> anyhow::Result<()> { fn openai_gpt_lm_model() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT, OpenAiGptConfigResources::GPT,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT, OpenAiGptVocabResources::GPT,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT, OpenAiGptMergesResources::GPT,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT, OpenAiGptModelResources::GPT,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -104,16 +104,16 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
#[test] #[test]
fn openai_gpt_generation_greedy() -> anyhow::Result<()> { fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT, OpenAiGptConfigResources::GPT,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT, OpenAiGptVocabResources::GPT,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT, OpenAiGptMergesResources::GPT,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT, OpenAiGptModelResources::GPT,
)); ));
@ -146,16 +146,16 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
#[test] #[test]
fn openai_gpt_generation_beam_search() -> anyhow::Result<()> { fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT, OpenAiGptConfigResources::GPT,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT, OpenAiGptVocabResources::GPT,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT, OpenAiGptMergesResources::GPT,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT, OpenAiGptModelResources::GPT,
)); ));
@ -199,16 +199,16 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
#[test] #[test]
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Result<()> { fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT, OpenAiGptConfigResources::GPT,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT, OpenAiGptVocabResources::GPT,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT, OpenAiGptMergesResources::GPT,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT, OpenAiGptModelResources::GPT,
)); ));
@ -268,16 +268,16 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
#[test] #[test]
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> { fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT, OpenAiGptConfigResources::GPT,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT, OpenAiGptVocabResources::GPT,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT, OpenAiGptMergesResources::GPT,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT, OpenAiGptModelResources::GPT,
)); ));

View File

@ -2,19 +2,19 @@ use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationMode
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources}; use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use tch::Device; use tch::Device;
#[test] #[test]
fn pegasus_summarization_greedy() -> anyhow::Result<()> { fn pegasus_summarization_greedy() -> anyhow::Result<()> {
// Set-up model // Set-up model
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
PegasusConfigResources::CNN_DAILYMAIL, PegasusConfigResources::CNN_DAILYMAIL,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
PegasusVocabResources::CNN_DAILYMAIL, PegasusVocabResources::CNN_DAILYMAIL,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
PegasusModelResources::CNN_DAILYMAIL, PegasusModelResources::CNN_DAILYMAIL,
)); ));

View File

@ -4,19 +4,19 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::prophetnet::{ use rust_bert::prophetnet::{
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources, ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
}; };
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use tch::Device; use tch::Device;
#[test] #[test]
fn prophetnet_summarization_greedy() -> anyhow::Result<()> { fn prophetnet_summarization_greedy() -> anyhow::Result<()> {
// Set-up model // Set-up model
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM, ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM, ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM, ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
)); ));

View File

@ -4,7 +4,7 @@ use rust_bert::reformer::{
ReformerConfig, ReformerConfigResources, ReformerForQuestionAnswering, ReformerConfig, ReformerConfigResources, ReformerForQuestionAnswering,
ReformerForSequenceClassification, ReformerModelResources, ReformerVocabResources, ReformerForSequenceClassification, ReformerModelResources, ReformerVocabResources,
}; };
use rust_bert::resources::{LocalResource, RemoteResource, Resource}; use rust_bert::resources::{LocalResource, RemoteResource, ResourceProvider};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, ReformerTokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, ReformerTokenizer, TruncationStrategy};
use std::collections::HashMap; use std::collections::HashMap;
@ -17,7 +17,7 @@ use tch::{nn, no_grad, Device, Tensor};
fn test_generation_reformer() -> anyhow::Result<()> { fn test_generation_reformer() -> anyhow::Result<()> {
// =================================================== // ===================================================
// Modify resource to enforce seed // Modify resource to enforce seed
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT, ReformerConfigResources::CRIME_AND_PUNISHMENT,
)); ));
@ -31,18 +31,18 @@ fn test_generation_reformer() -> anyhow::Result<()> {
let _ = updated_config_file.write_all(serde_json::to_string(&config).unwrap().as_bytes()); let _ = updated_config_file.write_all(serde_json::to_string(&config).unwrap().as_bytes());
let updated_config_path = updated_config_file.into_temp_path(); let updated_config_path = updated_config_file.into_temp_path();
let config_resource = Resource::Local(LocalResource { let config_resource = Box::new(LocalResource {
local_path: updated_config_path.to_path_buf(), local_path: updated_config_path.to_path_buf(),
}); });
// =================================================== // ===================================================
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT, ReformerVocabResources::CRIME_AND_PUNISHMENT,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT, ReformerVocabResources::CRIME_AND_PUNISHMENT,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT, ReformerModelResources::CRIME_AND_PUNISHMENT,
)); ));
// Set-up translation model // Set-up translation model
@ -79,10 +79,10 @@ fn test_generation_reformer() -> anyhow::Result<()> {
#[test] #[test]
fn reformer_for_sequence_classification() -> anyhow::Result<()> { fn reformer_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT, ReformerConfigResources::CRIME_AND_PUNISHMENT,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT, ReformerVocabResources::CRIME_AND_PUNISHMENT,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -145,10 +145,10 @@ fn reformer_for_sequence_classification() -> anyhow::Result<()> {
#[test] #[test]
fn reformer_for_question_answering() -> anyhow::Result<()> { fn reformer_for_question_answering() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT, ReformerConfigResources::CRIME_AND_PUNISHMENT,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT, ReformerVocabResources::CRIME_AND_PUNISHMENT,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;

View File

@ -5,7 +5,7 @@ use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
}; };
use rust_bert::pipelines::token_classification::TokenClassificationConfig; use rust_bert::pipelines::token_classification::TokenClassificationConfig;
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::roberta::{ use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice, RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
RobertaForSequenceClassification, RobertaForTokenClassification, RobertaMergesResources, RobertaForSequenceClassification, RobertaForTokenClassification, RobertaMergesResources,
@ -20,18 +20,13 @@ use tch::{nn, no_grad, Device, Tensor};
#[test] #[test]
fn roberta_masked_lm() -> anyhow::Result<()> { fn roberta_masked_lm() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource =
RobertaConfigResources::DISTILROBERTA_BASE, RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
)); let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource =
RobertaVocabResources::DISTILROBERTA_BASE, RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
)); let weights_resource =
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(RobertaModelResources::DISTILROBERTA_BASE);
RobertaMergesResources::DISTILROBERTA_BASE,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::DISTILROBERTA_BASE,
));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?;
@ -116,15 +111,11 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
#[test] #[test]
fn roberta_for_sequence_classification() -> anyhow::Result<()> { fn roberta_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource =
RobertaConfigResources::DISTILROBERTA_BASE, RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
)); let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource =
RobertaVocabResources::DISTILROBERTA_BASE, RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::DISTILROBERTA_BASE,
));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?;
@ -190,15 +181,11 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
#[test] #[test]
fn roberta_for_multiple_choice() -> anyhow::Result<()> { fn roberta_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource =
RobertaConfigResources::DISTILROBERTA_BASE, RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
)); let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource =
RobertaVocabResources::DISTILROBERTA_BASE, RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::DISTILROBERTA_BASE,
));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?;
@ -260,15 +247,11 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
#[test] #[test]
fn roberta_for_token_classification() -> anyhow::Result<()> { fn roberta_for_token_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource =
RobertaConfigResources::DISTILROBERTA_BASE, RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
)); let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource =
RobertaVocabResources::DISTILROBERTA_BASE, RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::DISTILROBERTA_BASE,
));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?;
@ -337,18 +320,12 @@ fn roberta_question_answering() -> anyhow::Result<()> {
// Set-up question answering model // Set-up question answering model
let config = QuestionAnsweringConfig::new( let config = QuestionAnsweringConfig::new(
ModelType::Roberta, ModelType::Roberta,
Resource::Remote(RemoteResource::from_pretrained( RemoteResource::from_pretrained(RobertaModelResources::ROBERTA_QA),
RobertaModelResources::ROBERTA_QA, RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA_QA),
)), RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA_QA),
Resource::Remote(RemoteResource::from_pretrained( Some(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA_QA,
)),
Some(Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA_QA, RobertaMergesResources::ROBERTA_QA,
))), )),
false, false,
None, None,
false, false,
@ -378,13 +355,13 @@ fn xlm_roberta_german_ner() -> anyhow::Result<()> {
// Set-up question answering model // Set-up question answering model
let ner_config = TokenClassificationConfig { let ner_config = TokenClassificationConfig {
model_type: ModelType::XLMRoberta, model_type: ModelType::XLMRoberta,
model_resource: Resource::Remote(RemoteResource::from_pretrained( model_resource: Box::new(RemoteResource::from_pretrained(
RobertaModelResources::XLM_ROBERTA_NER_DE, RobertaModelResources::XLM_ROBERTA_NER_DE,
)), )),
config_resource: Resource::Remote(RemoteResource::from_pretrained( config_resource: Box::new(RemoteResource::from_pretrained(
RobertaConfigResources::XLM_ROBERTA_NER_DE, RobertaConfigResources::XLM_ROBERTA_NER_DE,
)), )),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained( vocab_resource: Box::new(RemoteResource::from_pretrained(
RobertaVocabResources::XLM_ROBERTA_NER_DE, RobertaVocabResources::XLM_ROBERTA_NER_DE,
)), )),
lower_case: false, lower_case: false,

View File

@ -1,20 +1,16 @@
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources}; use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
use tch::Device; use tch::Device;
#[test] #[test]
fn test_translation_t5() -> anyhow::Result<()> { fn test_translation_t5() -> anyhow::Result<()> {
let model_resource = let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_SMALL);
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)); let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL);
let config_resource = let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)); let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let source_languages = [ let source_languages = [
Language::English, Language::English,
@ -70,18 +66,10 @@ fn test_summarization_t5() -> anyhow::Result<()> {
// Set-up translation model // Set-up translation model
let summarization_config = SummarizationConfig { let summarization_config = SummarizationConfig {
model_type: ModelType::T5, model_type: ModelType::T5,
model_resource: Resource::Remote(RemoteResource::from_pretrained( model_resource: Box::new(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)),
T5ModelResources::T5_SMALL, config_resource: Box::new(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)),
)), vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
config_resource: Resource::Remote(RemoteResource::from_pretrained( merges_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
T5ConfigResources::T5_SMALL,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
T5VocabResources::T5_SMALL,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
T5VocabResources::T5_SMALL,
)),
min_length: 30, min_length: 30,
max_length: 200, max_length: 200,
early_stopping: true, early_stopping: true,

View File

@ -1,6 +1,6 @@
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::xlnet::{ use rust_bert::xlnet::{
XLNetConfig, XLNetConfigResources, XLNetForMultipleChoice, XLNetForQuestionAnswering, XLNetConfig, XLNetConfigResources, XLNetForMultipleChoice, XLNetForQuestionAnswering,
XLNetForSequenceClassification, XLNetForTokenClassification, XLNetLMHeadModel, XLNetModel, XLNetForSequenceClassification, XLNetForTokenClassification, XLNetLMHeadModel, XLNetModel,
@ -15,13 +15,13 @@ use tch::{nn, no_grad, Device, Kind, Tensor};
#[test] #[test]
fn xlnet_base_model() -> anyhow::Result<()> { fn xlnet_base_model() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED, XLNetConfigResources::XLNET_BASE_CASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED, XLNetModelResources::XLNET_BASE_CASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -122,13 +122,13 @@ fn xlnet_base_model() -> anyhow::Result<()> {
#[test] #[test]
fn xlnet_lm_model() -> anyhow::Result<()> { fn xlnet_lm_model() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED, XLNetConfigResources::XLNET_BASE_CASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained( let weights_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED, XLNetModelResources::XLNET_BASE_CASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -196,16 +196,16 @@ fn xlnet_lm_model() -> anyhow::Result<()> {
#[test] #[test]
fn xlnet_generation_beam_search() -> anyhow::Result<()> { fn xlnet_generation_beam_search() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED, XLNetConfigResources::XLNET_BASE_CASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained( let merges_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let model_resource = Resource::Remote(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED, XLNetModelResources::XLNET_BASE_CASED,
)); ));
@ -239,10 +239,10 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
#[test] #[test]
fn xlnet_for_sequence_classification() -> anyhow::Result<()> { fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED, XLNetConfigResources::XLNET_BASE_CASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -311,10 +311,10 @@ fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
#[test] #[test]
fn xlnet_for_multiple_choice() -> anyhow::Result<()> { fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED, XLNetConfigResources::XLNET_BASE_CASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -379,10 +379,10 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
#[test] #[test]
fn xlnet_for_token_classification() -> anyhow::Result<()> { fn xlnet_for_token_classification() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED, XLNetConfigResources::XLNET_BASE_CASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;
@ -442,10 +442,10 @@ fn xlnet_for_token_classification() -> anyhow::Result<()> {
#[test] #[test]
fn xlnet_for_question_answering() -> anyhow::Result<()> { fn xlnet_for_question_answering() -> anyhow::Result<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained( let config_resource = Box::new(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED, XLNetConfigResources::XLNET_BASE_CASED,
)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let config_path = config_resource.get_local_path()?; let config_path = config_resource.get_local_path()?;