mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
Refactor: Feature gate remote resource (#223)
* get_local_path as trait LocalPathProvider * Remove config default impls * Feature gate RemoteResource * translation_builder refactoring to have remote fetching grouped * Include dirs crate in remote feature gate * Examples fixes * Benches fixes * Tests fix * Remove Box from constructor parameters * Fix examples no-Box * Fix benches no-Box * Fix tests no-Box * Fix doc comment code * Fix documentation `Resource` -> `ResourceProvider` * moved remote local at same level * moved ResourceProvider to resources mod Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
This commit is contained in:
parent
23c5d9112a
commit
9b22c2482a
@ -50,8 +50,10 @@ harness = false
|
|||||||
opt-level = 3
|
opt-level = 3
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
|
default = ["remote"]
|
||||||
doc-only = ["tch/doc-only"]
|
doc-only = ["tch/doc-only"]
|
||||||
all-tests = []
|
all-tests = []
|
||||||
|
remote = [ "cached-path", "dirs", "lazy_static" ]
|
||||||
|
|
||||||
[package.metadata.docs.rs]
|
[package.metadata.docs.rs]
|
||||||
features = ["doc-only"]
|
features = ["doc-only"]
|
||||||
@ -61,14 +63,15 @@ rust_tokenizers = "~7.0.1"
|
|||||||
tch = "~0.6.1"
|
tch = "~0.6.1"
|
||||||
serde_json = "1.0.73"
|
serde_json = "1.0.73"
|
||||||
serde = { version = "1.0.132", features = ["derive"] }
|
serde = { version = "1.0.132", features = ["derive"] }
|
||||||
dirs = "4.0.0"
|
|
||||||
ordered-float = "2.8.0"
|
ordered-float = "2.8.0"
|
||||||
cached-path = "0.5.1"
|
|
||||||
lazy_static = "1.4.0"
|
|
||||||
uuid = { version = "0.8.2", features = ["v4"] }
|
uuid = { version = "0.8.2", features = ["v4"] }
|
||||||
thiserror = "1.0.30"
|
thiserror = "1.0.30"
|
||||||
half = "1.8.2"
|
half = "1.8.2"
|
||||||
|
|
||||||
|
cached-path = { version = "0.5.1", optional = true }
|
||||||
|
dirs = { version = "4.0.0", optional = true }
|
||||||
|
lazy_static = { version = "1.4.0", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = "1.0.51"
|
anyhow = "1.0.51"
|
||||||
csv = "1.1.6"
|
csv = "1.1.6"
|
||||||
|
@ -7,21 +7,17 @@ use rust_bert::gpt2::{
|
|||||||
};
|
};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
fn create_text_generation_model() -> TextGenerationModel {
|
fn create_text_generation_model() -> TextGenerationModel {
|
||||||
let config = TextGenerationConfig {
|
let config = TextGenerationConfig {
|
||||||
model_type: ModelType::GPT2,
|
model_type: ModelType::GPT2,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
|
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
|
||||||
Gpt2ConfigResources::GPT2,
|
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
|
||||||
)),
|
merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
|
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
Gpt2MergesResources::GPT2,
|
|
||||||
)),
|
|
||||||
min_length: 0,
|
min_length: 0,
|
||||||
max_length: 30,
|
max_length: 30,
|
||||||
do_sample: true,
|
do_sample: true,
|
||||||
|
@ -7,7 +7,7 @@ use rust_bert::pipelines::common::ModelType;
|
|||||||
use rust_bert::pipelines::question_answering::{
|
use rust_bert::pipelines::question_answering::{
|
||||||
squad_processor, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
squad_processor, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
@ -17,11 +17,9 @@ static BATCH_SIZE: usize = 64;
|
|||||||
fn create_qa_model() -> QuestionAnsweringModel {
|
fn create_qa_model() -> QuestionAnsweringModel {
|
||||||
let config = QuestionAnsweringConfig::new(
|
let config = QuestionAnsweringConfig::new(
|
||||||
ModelType::Bert,
|
ModelType::Bert,
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
|
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
|
||||||
BertConfigResources::BERT_QA,
|
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
|
||||||
)),
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
|
|
||||||
None, //merges resource only relevant with ModelType::Roberta
|
None, //merges resource only relevant with ModelType::Roberta
|
||||||
false, //lowercase
|
false, //lowercase
|
||||||
false,
|
false,
|
||||||
@ -54,11 +52,9 @@ fn qa_load_model(iters: u64) -> Duration {
|
|||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let config = QuestionAnsweringConfig::new(
|
let config = QuestionAnsweringConfig::new(
|
||||||
ModelType::Bert,
|
ModelType::Bert,
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
|
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
|
||||||
BertConfigResources::BERT_QA,
|
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
|
||||||
)),
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
|
|
||||||
None, //merges resource only relevant with ModelType::Roberta
|
None, //merges resource only relevant with ModelType::Roberta
|
||||||
false, //lowercase
|
false, //lowercase
|
||||||
false,
|
false,
|
||||||
|
@ -19,21 +19,21 @@ use rust_bert::gpt_neo::{
|
|||||||
};
|
};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
// Set-up model resources
|
// Set-up model resources
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoConfigResources::GPT_NEO_125M,
|
GptNeoConfigResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoVocabResources::GPT_NEO_125M,
|
GptNeoVocabResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoMergesResources::GPT_NEO_125M,
|
GptNeoMergesResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoModelResources::GPT_NEO_125M,
|
GptNeoModelResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let generate_config = TextGenerationConfig {
|
let generate_config = TextGenerationConfig {
|
||||||
|
@ -19,21 +19,21 @@ use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGeneration
|
|||||||
use rust_bert::reformer::{
|
use rust_bert::reformer::{
|
||||||
ReformerConfigResources, ReformerModelResources, ReformerVocabResources,
|
ReformerConfigResources, ReformerModelResources, ReformerVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
// Set-up model
|
// Set-up model
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerModelResources::CRIME_AND_PUNISHMENT,
|
ReformerModelResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
let generate_config = TextGenerationConfig {
|
let generate_config = TextGenerationConfig {
|
||||||
|
@ -16,21 +16,21 @@ extern crate anyhow;
|
|||||||
|
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
|
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetConfigResources::XLNET_BASE_CASED,
|
XLNetConfigResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetVocabResources::XLNET_BASE_CASED,
|
XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetVocabResources::XLNET_BASE_CASED,
|
XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetModelResources::XLNET_BASE_CASED,
|
XLNetModelResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ extern crate anyhow;
|
|||||||
use rust_bert::bert::{
|
use rust_bert::bert::{
|
||||||
BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
|
BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::Vocab;
|
use rust_tokenizers::vocab::Vocab;
|
||||||
@ -23,12 +23,9 @@ use tch::{nn, no_grad, Device, Tensor};
|
|||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
|
||||||
let vocab_resource =
|
let weights_resource = RemoteResource::from_pretrained(BertModelResources::BERT);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
|
||||||
let weights_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let weights_path = weights_resource.get_local_path()?;
|
let weights_path = weights_resource.get_local_path()?;
|
||||||
|
@ -4,23 +4,23 @@ use rust_bert::deberta::{
|
|||||||
DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
|
DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
|
||||||
DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
|
DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||||
use tch::{nn, no_grad, Device, Kind, Tensor};
|
use tch::{nn, no_grad, Device, Kind, Tensor};
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaVocabResources::DEBERTA_BASE_MNLI,
|
DebertaVocabResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaMergesResources::DEBERTA_BASE_MNLI,
|
DebertaMergesResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaModelResources::DEBERTA_BASE_MNLI,
|
DebertaModelResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -17,17 +17,15 @@ use rust_bert::pipelines::common::ModelType;
|
|||||||
use rust_bert::pipelines::question_answering::{
|
use rust_bert::pipelines::question_answering::{
|
||||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
// Set-up Question Answering model
|
// Set-up Question Answering model
|
||||||
let config = QuestionAnsweringConfig::new(
|
let config = QuestionAnsweringConfig::new(
|
||||||
ModelType::Bert,
|
ModelType::Bert,
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
|
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
|
||||||
BertConfigResources::BERT_QA,
|
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
|
||||||
)),
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
|
|
||||||
None, //merges resource only relevant with ModelType::Roberta
|
None, //merges resource only relevant with ModelType::Roberta
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -20,24 +20,18 @@ use rust_bert::pipelines::common::ModelType;
|
|||||||
use rust_bert::pipelines::question_answering::{
|
use rust_bert::pipelines::question_answering::{
|
||||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
// Set-up Question Answering model
|
// Set-up Question Answering model
|
||||||
let config = QuestionAnsweringConfig::new(
|
let config = QuestionAnsweringConfig::new(
|
||||||
ModelType::Longformer,
|
ModelType::Longformer,
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
|
||||||
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
|
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
|
||||||
)),
|
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
Some(RemoteResource::from_pretrained(
|
||||||
LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
|
|
||||||
)),
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
|
|
||||||
)),
|
|
||||||
Some(Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
|
LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
|
||||||
))),
|
)),
|
||||||
false,
|
false,
|
||||||
None,
|
None,
|
||||||
false,
|
false,
|
||||||
|
@ -15,17 +15,17 @@ extern crate anyhow;
|
|||||||
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
|
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
|
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
// Set-up classifier
|
// Set-up classifier
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
FNetConfigResources::BASE_SST2,
|
FNetConfigResources::BASE_SST2,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
FNetVocabResources::BASE_SST2,
|
FNetVocabResources::BASE_SST2,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
FNetModelResources::BASE_SST2,
|
FNetModelResources::BASE_SST2,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -16,20 +16,20 @@ use rust_bert::bart::{
|
|||||||
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
|
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartConfigResources::DISTILBART_CNN_6_6,
|
BartConfigResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartVocabResources::DISTILBART_CNN_6_6,
|
BartVocabResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartMergesResources::DISTILBART_CNN_6_6,
|
BartMergesResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartModelResources::DISTILBART_CNN_6_6,
|
BartModelResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -15,17 +15,17 @@ extern crate anyhow;
|
|||||||
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
|
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
PegasusConfigResources::CNN_DAILYMAIL,
|
PegasusConfigResources::CNN_DAILYMAIL,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
PegasusVocabResources::CNN_DAILYMAIL,
|
PegasusVocabResources::CNN_DAILYMAIL,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
PegasusModelResources::CNN_DAILYMAIL,
|
PegasusModelResources::CNN_DAILYMAIL,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -17,17 +17,17 @@ use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationMode
|
|||||||
use rust_bert::prophetnet::{
|
use rust_bert::prophetnet::{
|
||||||
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
|
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
|
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
|
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
|
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -14,16 +14,14 @@ extern crate anyhow;
|
|||||||
|
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let config_resource =
|
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
|
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
|
||||||
let vocab_resource =
|
let weights_resource = RemoteResource::from_pretrained(T5ModelResources::T5_SMALL);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
|
|
||||||
let weights_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
|
|
||||||
let summarization_config = SummarizationConfig::new(
|
let summarization_config = SummarizationConfig::new(
|
||||||
ModelType::T5,
|
ModelType::T5,
|
||||||
weights_resource,
|
weights_resource,
|
||||||
|
@ -16,21 +16,15 @@ use rust_bert::pipelines::ner::NERModel;
|
|||||||
use rust_bert::pipelines::token_classification::{
|
use rust_bert::pipelines::token_classification::{
|
||||||
LabelAggregationOption, TokenClassificationConfig,
|
LabelAggregationOption, TokenClassificationConfig,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
// Load a configuration
|
// Load a configuration
|
||||||
let config = TokenClassificationConfig::new(
|
let config = TokenClassificationConfig::new(
|
||||||
ModelType::Bert,
|
ModelType::Bert,
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(BertModelResources::BERT_NER),
|
||||||
BertModelResources::BERT_NER,
|
RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
|
||||||
)),
|
RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
BertConfigResources::BERT_NER,
|
|
||||||
)),
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
BertVocabResources::BERT_NER,
|
|
||||||
)),
|
|
||||||
None, //merges resource only relevant with ModelType::Roberta
|
None, //merges resource only relevant with ModelType::Roberta
|
||||||
false, //lowercase
|
false, //lowercase
|
||||||
false,
|
false,
|
||||||
|
@ -18,22 +18,14 @@ use rust_bert::m2m_100::{
|
|||||||
};
|
};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
|
||||||
M2M100ModelResources::M2M100_418M,
|
let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
|
||||||
));
|
let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
|
||||||
M2M100ConfigResources::M2M100_418M,
|
|
||||||
));
|
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100VocabResources::M2M100_418M,
|
|
||||||
));
|
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100MergesResources::M2M100_418M,
|
|
||||||
));
|
|
||||||
|
|
||||||
let source_languages = M2M100SourceLanguages::M2M100_418M;
|
let source_languages = M2M100SourceLanguages::M2M100_418M;
|
||||||
let target_languages = M2M100TargetLanguages::M2M100_418M;
|
let target_languages = M2M100TargetLanguages::M2M100_418M;
|
||||||
|
@ -19,22 +19,14 @@ use rust_bert::marian::{
|
|||||||
};
|
};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
|
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = RemoteResource::from_pretrained(MarianModelResources::ENGLISH2CHINESE);
|
||||||
MarianModelResources::ENGLISH2CHINESE,
|
let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ENGLISH2CHINESE);
|
||||||
));
|
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ENGLISH2CHINESE);
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ENGLISH2CHINESE);
|
||||||
MarianConfigResources::ENGLISH2CHINESE,
|
|
||||||
));
|
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
MarianVocabResources::ENGLISH2CHINESE,
|
|
||||||
));
|
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
MarianSpmResources::ENGLISH2CHINESE,
|
|
||||||
));
|
|
||||||
|
|
||||||
let source_languages = MarianSourceLanguages::ENGLISH2CHINESE;
|
let source_languages = MarianSourceLanguages::ENGLISH2CHINESE;
|
||||||
let target_languages = MarianTargetLanguages::ENGLISH2CHINESE;
|
let target_languages = MarianTargetLanguages::ENGLISH2CHINESE;
|
||||||
|
@ -18,22 +18,16 @@ use rust_bert::mbart::{
|
|||||||
};
|
};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = RemoteResource::from_pretrained(MBartModelResources::MBART50_MANY_TO_MANY);
|
||||||
MBartModelResources::MBART50_MANY_TO_MANY,
|
let config_resource =
|
||||||
));
|
RemoteResource::from_pretrained(MBartConfigResources::MBART50_MANY_TO_MANY);
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
|
||||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
let merges_resource =
|
||||||
));
|
RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
|
||||||
));
|
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
|
||||||
));
|
|
||||||
|
|
||||||
let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY;
|
let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY;
|
||||||
let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY;
|
let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY;
|
||||||
|
@ -14,19 +14,15 @@ extern crate anyhow;
|
|||||||
|
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let model_resource =
|
let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_BASE);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE));
|
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_BASE);
|
||||||
let config_resource =
|
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_BASE));
|
let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
|
||||||
let vocab_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
|
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
|
|
||||||
|
|
||||||
let source_languages = [
|
let source_languages = [
|
||||||
Language::English,
|
Language::English,
|
||||||
|
@ -24,19 +24,19 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
|
//! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::AlbertTokenizer;
|
//! use rust_tokenizers::tokenizer::AlbertTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let weights_path = weights_resource.get_local_path()?;
|
//! let weights_path = weights_resource.get_local_path()?;
|
||||||
|
@ -17,10 +17,6 @@ use crate::bart::encoder::BartEncoder;
|
|||||||
use crate::common::activations::Activation;
|
use crate::common::activations::Activation;
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use crate::common::kind::get_negative_infinity;
|
use crate::common::kind::get_negative_infinity;
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
|
||||||
use crate::gpt2::{
|
|
||||||
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
|
||||||
};
|
|
||||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||||
use crate::pipelines::generation_utils::private_generation_utils::{
|
use crate::pipelines::generation_utils::private_generation_utils::{
|
||||||
PreparedInput, PrivateLanguageGenerator,
|
PreparedInput, PrivateLanguageGenerator,
|
||||||
@ -1028,43 +1024,10 @@ impl BartGenerator {
|
|||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
|
pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
|
||||||
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
|
let config_path = generate_config.config_resource.get_local_path()?;
|
||||||
let model_resource = if generate_config.model_resource
|
let vocab_path = generate_config.vocab_resource.get_local_path()?;
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
|
let merges_path = generate_config.merges_resource.get_local_path()?;
|
||||||
{
|
let weights_path = generate_config.model_resource.get_local_path()?;
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART))
|
|
||||||
} else {
|
|
||||||
generate_config.model_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_resource = if generate_config.config_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART))
|
|
||||||
} else {
|
|
||||||
generate_config.config_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let vocab_resource = if generate_config.vocab_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART))
|
|
||||||
} else {
|
|
||||||
generate_config.vocab_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let merges_resource = if generate_config.merges_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART))
|
|
||||||
} else {
|
|
||||||
generate_config.merges_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
|
||||||
let weights_path = model_resource.get_local_path()?;
|
|
||||||
let device = generate_config.device;
|
let device = generate_config.device;
|
||||||
|
|
||||||
generate_config.validate();
|
generate_config.validate();
|
||||||
@ -1293,7 +1256,7 @@ mod test {
|
|||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
resources::{RemoteResource, Resource},
|
resources::{RemoteResource, ResourceProvider},
|
||||||
Config,
|
Config,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1302,8 +1265,7 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
#[ignore] // compilation is enough, no need to run
|
#[ignore] // compilation is enough, no need to run
|
||||||
fn bart_model_send() {
|
fn bart_model_send() {
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
|
||||||
let config_path = config_resource.get_local_path().expect("");
|
let config_path = config_resource.get_local_path().expect("");
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
|
@ -19,22 +19,22 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::bart::{BartConfig, BartModel};
|
//! use rust_bert::bart::{BartConfig, BartModel};
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::RobertaTokenizer;
|
//! use rust_tokenizers::tokenizer::RobertaTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let merges_resource = Resource::Local(LocalResource {
|
//! let merges_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let merges_path = merges_resource.get_local_path()?;
|
//! let merges_path = merges_resource.get_local_path()?;
|
||||||
|
@ -1215,7 +1215,7 @@ mod test {
|
|||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
resources::{RemoteResource, Resource},
|
resources::{RemoteResource, ResourceProvider},
|
||||||
Config,
|
Config,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1224,8 +1224,7 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
#[ignore] // compilation is enough, no need to run
|
#[ignore] // compilation is enough, no need to run
|
||||||
fn bert_model_send() {
|
fn bert_model_send() {
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
|
||||||
let config_path = config_resource.get_local_path().expect("");
|
let config_path = config_resource.get_local_path().expect("");
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
|
@ -24,19 +24,19 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::bert::{BertConfig, BertForMaskedLM};
|
//! use rust_bert::bert::{BertConfig, BertForMaskedLM};
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::BertTokenizer;
|
//! use rust_tokenizers::tokenizer::BertTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let weights_path = weights_resource.get_local_path()?;
|
//! let weights_path = weights_resource.get_local_path()?;
|
||||||
|
@ -4,8 +4,9 @@ use thiserror::Error;
|
|||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum RustBertError {
|
pub enum RustBertError {
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
#[error("Endpoint not available error: {0}")]
|
#[error("Endpoint not available error: {0}")]
|
||||||
FileDownloadError(String),
|
FileDownloadError(#[from] cached_path::Error),
|
||||||
|
|
||||||
#[error("IO error: {0}")]
|
#[error("IO error: {0}")]
|
||||||
IOError(String),
|
IOError(String),
|
||||||
@ -23,12 +24,6 @@ pub enum RustBertError {
|
|||||||
ValueError(String),
|
ValueError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<cached_path::Error> for RustBertError {
|
|
||||||
fn from(error: cached_path::Error) -> Self {
|
|
||||||
RustBertError::FileDownloadError(error.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<std::io::Error> for RustBertError {
|
impl From<std::io::Error> for RustBertError {
|
||||||
fn from(error: std::io::Error) -> Self {
|
fn from(error: std::io::Error) -> Self {
|
||||||
RustBertError::IOError(error.to_string())
|
RustBertError::IOError(error.to_string())
|
||||||
|
@ -1,197 +0,0 @@
|
|||||||
//! # Resource definitions for model weights, vocabularies and configuration files
|
|
||||||
//!
|
|
||||||
//! This crate relies on the concept of Resources to access the files used by the models.
|
|
||||||
//! This includes:
|
|
||||||
//! - model weights
|
|
||||||
//! - configuration files
|
|
||||||
//! - vocabularies
|
|
||||||
//! - (optional) merges files for BPE-based tokenizers
|
|
||||||
//!
|
|
||||||
//! These are expected in the pipelines configurations or are used as utilities to reference to the
|
|
||||||
//! resource location. Two types of resources exist:
|
|
||||||
//! - LocalResource: points to a local file
|
|
||||||
//! - RemoteResource: points to a remote file via a URL and a local cached file
|
|
||||||
//!
|
|
||||||
//! For both types of resources, the local location of teh file can be retrieved using
|
|
||||||
//! `get_local_path`, allowing to reference the resource file location regardless if it is a remote
|
|
||||||
//! or local resource. Default implementations for a number of `RemoteResources` are available as
|
|
||||||
//! pre-trained models in each model module.
|
|
||||||
|
|
||||||
use crate::common::error::RustBertError;
|
|
||||||
use cached_path::{Cache, Options, ProgressBar};
|
|
||||||
use lazy_static::lazy_static;
|
|
||||||
use std::env;
|
|
||||||
use std::path::PathBuf;
|
|
||||||
|
|
||||||
extern crate dirs;
|
|
||||||
|
|
||||||
/// # Resource Enum pointing to model, configuration or vocabulary resources
|
|
||||||
/// Can be of type:
|
|
||||||
/// - LocalResource
|
|
||||||
/// - RemoteResource
|
|
||||||
#[derive(PartialEq, Clone)]
|
|
||||||
pub enum Resource {
|
|
||||||
Local(LocalResource),
|
|
||||||
Remote(RemoteResource),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Resource {
|
|
||||||
/// Gets the local path for a given resource.
|
|
||||||
///
|
|
||||||
/// If the resource is a remote resource, it is downloaded and cached. Then the path
|
|
||||||
/// to the local cache is returned.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `PathBuf` pointing to the resource file
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```no_run
|
|
||||||
/// use rust_bert::resources::{LocalResource, Resource};
|
|
||||||
/// use std::path::PathBuf;
|
|
||||||
/// let config_resource = Resource::Local(LocalResource {
|
|
||||||
/// local_path: PathBuf::from("path/to/config.json"),
|
|
||||||
/// });
|
|
||||||
/// let config_path = config_resource.get_local_path();
|
|
||||||
/// ```
|
|
||||||
pub fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
|
|
||||||
match self {
|
|
||||||
Resource::Local(resource) => Ok(resource.local_path.clone()),
|
|
||||||
Resource::Remote(resource) => {
|
|
||||||
let cached_path = CACHE.cached_path_with_options(
|
|
||||||
&resource.url,
|
|
||||||
&Options::default().subdir(&resource.cache_subdir),
|
|
||||||
)?;
|
|
||||||
Ok(cached_path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Local resource
|
|
||||||
#[derive(PartialEq, Clone)]
|
|
||||||
pub struct LocalResource {
|
|
||||||
/// Local path for the resource
|
|
||||||
pub local_path: PathBuf,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Remote resource
|
|
||||||
#[derive(PartialEq, Clone)]
|
|
||||||
pub struct RemoteResource {
|
|
||||||
/// Remote path/url for the resource
|
|
||||||
pub url: String,
|
|
||||||
/// Local subdirectory of the cache root where this resource is saved
|
|
||||||
pub cache_subdir: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RemoteResource {
|
|
||||||
/// Creates a new RemoteResource from an URL and a custom local path. Note that this does not
|
|
||||||
/// download the resource (only declares the remote and local locations)
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `url` - `&str` Location of the remote resource
|
|
||||||
/// * `cache_subdir` - `&str` Local subdirectory of the cache root to save the resource to
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `RemoteResource` RemoteResource object
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```no_run
|
|
||||||
/// use rust_bert::resources::{RemoteResource, Resource};
|
|
||||||
/// let config_resource = Resource::Remote(RemoteResource::new(
|
|
||||||
/// "configs",
|
|
||||||
/// "http://config_json_location",
|
|
||||||
/// ));
|
|
||||||
/// ```
|
|
||||||
pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
|
|
||||||
RemoteResource {
|
|
||||||
url: url.to_string(),
|
|
||||||
cache_subdir: cache_subdir.to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
|
|
||||||
/// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares
|
|
||||||
/// the remote and local locations)
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `name_url_tuple` - `(&str, &str)` Location of the name of model and remote resource
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `RemoteResource` RemoteResource object
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```no_run
|
|
||||||
/// use rust_bert::resources::{RemoteResource, Resource};
|
|
||||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
|
|
||||||
/// "distilbert-sst2",
|
|
||||||
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
|
|
||||||
/// )));
|
|
||||||
/// ```
|
|
||||||
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
|
|
||||||
let cache_subdir = name_url_tuple.0.to_string();
|
|
||||||
let url = name_url_tuple.1.to_string();
|
|
||||||
RemoteResource { url, cache_subdir }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
lazy_static! {
|
|
||||||
#[derive(Copy, Clone, Debug)]
|
|
||||||
/// # Global cache directory
|
|
||||||
/// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
|
|
||||||
/// location. Otherwise defaults to `~/.cache/.rustbert`.
|
|
||||||
pub static ref CACHE: Cache = Cache::builder()
|
|
||||||
.dir(_get_cache_directory())
|
|
||||||
.progress_bar(Some(ProgressBar::Light))
|
|
||||||
.build().unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn _get_cache_directory() -> PathBuf {
|
|
||||||
match env::var("RUSTBERT_CACHE") {
|
|
||||||
Ok(value) => PathBuf::from(value),
|
|
||||||
Err(_) => {
|
|
||||||
let mut home = dirs::home_dir().unwrap();
|
|
||||||
home.push(".cache");
|
|
||||||
home.push(".rustbert");
|
|
||||||
home
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[deprecated(
|
|
||||||
since = "0.9.1",
|
|
||||||
note = "Please use `Resource.get_local_path()` instead"
|
|
||||||
)]
|
|
||||||
/// # (Download) the resource and return a path to its local path
|
|
||||||
/// This function will download remote resource to their local path if they do not exist yet.
|
|
||||||
/// Then for both `LocalResource` and `RemoteResource`, it will the local path to the resource.
|
|
||||||
/// For `LocalResource` only the resource path is returned.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `resource` - Pointer to the `&Resource` to optionally download and get the local path.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `&PathBuf` Local path for the resource
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```no_run
|
|
||||||
/// use rust_bert::resources::{RemoteResource, Resource};
|
|
||||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
|
|
||||||
/// "distilbert-sst2/model.ot",
|
|
||||||
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
|
|
||||||
/// )));
|
|
||||||
/// let local_path = model_resource.get_local_path();
|
|
||||||
/// ```
|
|
||||||
pub fn download_resource(resource: &Resource) -> Result<PathBuf, RustBertError> {
|
|
||||||
resource.get_local_path()
|
|
||||||
}
|
|
32
src/common/resources/local.rs
Normal file
32
src/common/resources/local.rs
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
use crate::common::error::RustBertError;
|
||||||
|
use crate::resources::ResourceProvider;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
/// # Local resource
|
||||||
|
#[derive(PartialEq, Clone)]
|
||||||
|
pub struct LocalResource {
|
||||||
|
/// Local path for the resource
|
||||||
|
pub local_path: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResourceProvider for LocalResource {
|
||||||
|
/// Gets the path for a local resource.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `PathBuf` pointing to the resource file
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
|
/// use std::path::PathBuf;
|
||||||
|
/// let config_resource = LocalResource {
|
||||||
|
/// local_path: PathBuf::from("path/to/config.json"),
|
||||||
|
/// };
|
||||||
|
/// let config_path = config_resource.get_local_path();
|
||||||
|
/// ```
|
||||||
|
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
|
||||||
|
Ok(self.local_path.clone())
|
||||||
|
}
|
||||||
|
}
|
50
src/common/resources/mod.rs
Normal file
50
src/common/resources/mod.rs
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
//! # Resource definitions for model weights, vocabularies and configuration files
|
||||||
|
//!
|
||||||
|
//! This crate relies on the concept of Resources to access the files used by the models.
|
||||||
|
//! This includes:
|
||||||
|
//! - model weights
|
||||||
|
//! - configuration files
|
||||||
|
//! - vocabularies
|
||||||
|
//! - (optional) merges files for BPE-based tokenizers
|
||||||
|
//!
|
||||||
|
//! These are expected in the pipelines configurations or are used as utilities to reference to the
|
||||||
|
//! resource location. Two types of resources are pre-defined:
|
||||||
|
//! - LocalResource: points to a local file
|
||||||
|
//! - RemoteResource: points to a remote file via a URL
|
||||||
|
//!
|
||||||
|
//! For both types of resources, the local location of the file can be retrieved using
|
||||||
|
//! `get_local_path`, allowing to reference the resource file location regardless if it is a remote
|
||||||
|
//! or local resource. Default implementations for a number of `RemoteResources` are available as
|
||||||
|
//! pre-trained models in each model module.
|
||||||
|
|
||||||
|
mod local;
|
||||||
|
|
||||||
|
use crate::common::error::RustBertError;
|
||||||
|
pub use local::LocalResource;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
/// # Resource Trait that can provide the location of the model, configuration or vocabulary resources
|
||||||
|
pub trait ResourceProvider {
|
||||||
|
/// Provides the local path for a resource.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `PathBuf` pointing to the resource file
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
|
/// use std::path::PathBuf;
|
||||||
|
/// let config_resource = LocalResource {
|
||||||
|
/// local_path: PathBuf::from("path/to/config.json"),
|
||||||
|
/// };
|
||||||
|
/// let config_path = config_resource.get_local_path();
|
||||||
|
/// ```
|
||||||
|
fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
mod remote;
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
pub use remote::RemoteResource;
|
122
src/common/resources/remote.rs
Normal file
122
src/common/resources/remote.rs
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
use super::*;
|
||||||
|
use crate::common::error::RustBertError;
|
||||||
|
use cached_path::{Cache, Options, ProgressBar};
|
||||||
|
use dirs::cache_dir;
|
||||||
|
use lazy_static::lazy_static;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
/// # Remote resource that will be downloaded and cached locally on demand
|
||||||
|
#[derive(PartialEq, Clone)]
|
||||||
|
pub struct RemoteResource {
|
||||||
|
/// Remote path/url for the resource
|
||||||
|
pub url: String,
|
||||||
|
/// Local subdirectory of the cache root where this resource is saved
|
||||||
|
pub cache_subdir: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoteResource {
|
||||||
|
/// Creates a new RemoteResource from an URL and a custom local path. Note that this does not
|
||||||
|
/// download the resource (only declares the remote and local locations)
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `url` - `&str` Location of the remote resource
|
||||||
|
/// * `cache_subdir` - `&str` Local subdirectory of the cache root to save the resource to
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `RemoteResource` RemoteResource object
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use rust_bert::resources::RemoteResource;
|
||||||
|
/// let config_resource = RemoteResource::new(
|
||||||
|
/// "configs",
|
||||||
|
/// "http://config_json_location",
|
||||||
|
/// );
|
||||||
|
/// ```
|
||||||
|
pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
|
||||||
|
RemoteResource {
|
||||||
|
url: url.to_string(),
|
||||||
|
cache_subdir: cache_subdir.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
|
||||||
|
/// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares
|
||||||
|
/// the remote and local locations)
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `name_url_tuple` - `(&str, &str)` Location of the name of model and remote resource
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `RemoteResource` RemoteResource object
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use rust_bert::resources::RemoteResource;
|
||||||
|
/// let model_resource = RemoteResource::from_pretrained((
|
||||||
|
/// "distilbert-sst2",
|
||||||
|
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
|
||||||
|
/// ));
|
||||||
|
/// ```
|
||||||
|
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
|
||||||
|
let cache_subdir = name_url_tuple.0.to_string();
|
||||||
|
let url = name_url_tuple.1.to_string();
|
||||||
|
RemoteResource { url, cache_subdir }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResourceProvider for RemoteResource {
|
||||||
|
/// Gets the local path for a remote resource.
|
||||||
|
///
|
||||||
|
/// The remote resource is downloaded and cached. Then the path
|
||||||
|
/// to the local cache is returned.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `PathBuf` pointing to the resource file
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
|
/// use std::path::PathBuf;
|
||||||
|
/// let config_resource = LocalResource {
|
||||||
|
/// local_path: PathBuf::from("path/to/config.json"),
|
||||||
|
/// };
|
||||||
|
/// let config_path = config_resource.get_local_path();
|
||||||
|
/// ```
|
||||||
|
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
|
||||||
|
let cached_path = CACHE
|
||||||
|
.cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
|
||||||
|
Ok(cached_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lazy_static! {
|
||||||
|
#[derive(Copy, Clone, Debug)]
|
||||||
|
/// # Global cache directory
|
||||||
|
/// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
|
||||||
|
/// location. Otherwise defaults to `$XDG_CACHE_HOME/.rustbert`, or corresponding user cache for
|
||||||
|
/// the current system.
|
||||||
|
pub static ref CACHE: Cache = Cache::builder()
|
||||||
|
.dir(_get_cache_directory())
|
||||||
|
.progress_bar(Some(ProgressBar::Light))
|
||||||
|
.build().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _get_cache_directory() -> PathBuf {
|
||||||
|
match std::env::var("RUSTBERT_CACHE") {
|
||||||
|
Ok(value) => PathBuf::from(value),
|
||||||
|
Err(_) => {
|
||||||
|
let mut home = cache_dir().unwrap();
|
||||||
|
home.push(".rustbert");
|
||||||
|
home
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -23,22 +23,22 @@
|
|||||||
//! DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
|
//! DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
|
||||||
//! DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
|
//! DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
|
||||||
//! };
|
//! };
|
||||||
//! use rust_bert::resources::{RemoteResource, Resource};
|
//! use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::DeBERTaTokenizer;
|
//! use rust_tokenizers::tokenizer::DeBERTaTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let config_resource = RemoteResource::from_pretrained(
|
||||||
//! DebertaConfigResources::DEBERTA_BASE_MNLI,
|
//! DebertaConfigResources::DEBERTA_BASE_MNLI,
|
||||||
//! ));
|
//! );
|
||||||
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let vocab_resource = RemoteResource::from_pretrained(
|
||||||
//! DebertaVocabResources::DEBERTA_BASE_MNLI,
|
//! DebertaVocabResources::DEBERTA_BASE_MNLI,
|
||||||
//! ));
|
//! );
|
||||||
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let merges_resource = RemoteResource::from_pretrained(
|
||||||
//! DebertaMergesResources::DEBERTA_BASE_MNLI,
|
//! DebertaMergesResources::DEBERTA_BASE_MNLI,
|
||||||
//! ));
|
//! );
|
||||||
//! let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let weights_resource = RemoteResource::from_pretrained(
|
||||||
//! DebertaModelResources::DEBERTA_BASE_MNLI,
|
//! DebertaModelResources::DEBERTA_BASE_MNLI,
|
||||||
//! ));
|
//! );
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let merges_path = merges_resource.get_local_path()?;
|
//! let merges_path = merges_resource.get_local_path()?;
|
||||||
|
@ -25,19 +25,19 @@
|
|||||||
//! DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM,
|
//! DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM,
|
||||||
//! DistilBertModelResources, DistilBertVocabResources,
|
//! DistilBertModelResources, DistilBertVocabResources,
|
||||||
//! };
|
//! };
|
||||||
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::BertTokenizer;
|
//! use rust_tokenizers::tokenizer::BertTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let weights_path = weights_resource.get_local_path()?;
|
//! let weights_path = weights_resource.get_local_path()?;
|
||||||
|
@ -27,19 +27,19 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
|
//! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::BertTokenizer;
|
//! use rust_tokenizers::tokenizer::BertTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let weights_path = weights_resource.get_local_path()?;
|
//! let weights_path = weights_resource.get_local_path()?;
|
||||||
|
@ -1029,7 +1029,7 @@ mod test {
|
|||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
resources::{RemoteResource, Resource},
|
resources::{RemoteResource, ResourceProvider},
|
||||||
Config,
|
Config,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1038,8 +1038,7 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
#[ignore] // compilation is enough, no need to run
|
#[ignore] // compilation is enough, no need to run
|
||||||
fn fnet_model_send() {
|
fn fnet_model_send() {
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
|
|
||||||
let config_path = config_resource.get_local_path().expect("");
|
let config_path = config_resource.get_local_path().expect("");
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
|
@ -22,19 +22,19 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
|
//! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
|
||||||
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::{BertTokenizer, FNetTokenizer};
|
//! use rust_tokenizers::tokenizer::{BertTokenizer, FNetTokenizer};
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/spiece.model"),
|
//! local_path: PathBuf::from("path/to/spiece.model"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let weights_path = weights_resource.get_local_path()?;
|
//! let weights_path = weights_resource.get_local_path()?;
|
||||||
|
@ -19,22 +19,22 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
|
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::Gpt2Tokenizer;
|
//! use rust_tokenizers::tokenizer::Gpt2Tokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let merges_resource = Resource::Local(LocalResource {
|
//! let merges_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let merges_path = merges_resource.get_local_path()?;
|
//! let merges_path = merges_resource.get_local_path()?;
|
||||||
|
@ -22,20 +22,20 @@
|
|||||||
//! };
|
//! };
|
||||||
//! use rust_bert::pipelines::common::ModelType;
|
//! use rust_bert::pipelines::common::ModelType;
|
||||||
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||||
//! use rust_bert::resources::{RemoteResource, Resource};
|
//! use rust_bert::resources::RemoteResource;
|
||||||
//! use tch::Device;
|
//! use tch::Device;
|
||||||
//!
|
//!
|
||||||
//! fn main() -> anyhow::Result<()> {
|
//! fn main() -> anyhow::Result<()> {
|
||||||
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! GptNeoConfigResources::GPT_NEO_1_3B,
|
//! GptNeoConfigResources::GPT_NEO_1_3B,
|
||||||
//! ));
|
//! ));
|
||||||
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! GptNeoVocabResources::GPT_NEO_1_3B,
|
//! GptNeoVocabResources::GPT_NEO_1_3B,
|
||||||
//! ));
|
//! ));
|
||||||
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! GptNeoMergesResources::GPT_NEO_1_3B,
|
//! GptNeoMergesResources::GPT_NEO_1_3B,
|
||||||
//! ));
|
//! ));
|
||||||
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! GptNeoModelResources::GPT_NEO_1_3B,
|
//! GptNeoModelResources::GPT_NEO_1_3B,
|
||||||
//! ));
|
//! ));
|
||||||
//!
|
//!
|
||||||
|
@ -27,24 +27,24 @@
|
|||||||
//! use rust_bert::pipelines::question_answering::{
|
//! use rust_bert::pipelines::question_answering::{
|
||||||
//! QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
//! QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||||
//! };
|
//! };
|
||||||
//! use rust_bert::resources::{RemoteResource, Resource};
|
//! use rust_bert::resources::{RemoteResource};
|
||||||
//!
|
//!
|
||||||
//! fn main() -> anyhow::Result<()> {
|
//! fn main() -> anyhow::Result<()> {
|
||||||
//! // Set-up Question Answering model
|
//! // Set-up Question Answering model
|
||||||
//! let config = QuestionAnsweringConfig::new(
|
//! let config = QuestionAnsweringConfig::new(
|
||||||
//! ModelType::Longformer,
|
//! ModelType::Longformer,
|
||||||
//! Resource::Remote(RemoteResource::from_pretrained(
|
//! RemoteResource::from_pretrained(
|
||||||
//! LongformerModelResources::LONGFORMER_BASE_SQUAD1,
|
//! LongformerModelResources::LONGFORMER_BASE_SQUAD1,
|
||||||
//! )),
|
//! ),
|
||||||
//! Resource::Remote(RemoteResource::from_pretrained(
|
//! RemoteResource::from_pretrained(
|
||||||
//! LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
|
//! LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
|
||||||
//! )),
|
//! ),
|
||||||
//! Resource::Remote(RemoteResource::from_pretrained(
|
//! RemoteResource::from_pretrained(
|
||||||
//! LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
|
//! LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
|
||||||
//! )),
|
//! ),
|
||||||
//! Some(Resource::Remote(RemoteResource::from_pretrained(
|
//! Some(RemoteResource::from_pretrained(
|
||||||
//! LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
|
//! LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
|
||||||
//! ))),
|
//! )),
|
||||||
//! false,
|
//! false,
|
||||||
//! None,
|
//! None,
|
||||||
//! false,
|
//! false,
|
||||||
|
@ -10,9 +10,6 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::gpt2::{
|
|
||||||
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
|
||||||
};
|
|
||||||
use crate::m2m_100::decoder::M2M100Decoder;
|
use crate::m2m_100::decoder::M2M100Decoder;
|
||||||
use crate::m2m_100::encoder::M2M100Encoder;
|
use crate::m2m_100::encoder::M2M100Encoder;
|
||||||
use crate::m2m_100::LayerState;
|
use crate::m2m_100::LayerState;
|
||||||
@ -25,7 +22,6 @@ use crate::pipelines::generation_utils::{
|
|||||||
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
|
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
|
||||||
};
|
};
|
||||||
use crate::pipelines::translation::Language;
|
use crate::pipelines::translation::Language;
|
||||||
use crate::resources::{RemoteResource, Resource};
|
|
||||||
use crate::{Config, RustBertError};
|
use crate::{Config, RustBertError};
|
||||||
use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::{M2M100Vocab, Vocab};
|
use rust_tokenizers::vocab::{M2M100Vocab, Vocab};
|
||||||
@ -618,51 +614,10 @@ impl M2M100Generator {
|
|||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> {
|
pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> {
|
||||||
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
|
let config_path = generate_config.config_resource.get_local_path()?;
|
||||||
let model_resource = if generate_config.model_resource
|
let vocab_path = generate_config.vocab_resource.get_local_path()?;
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
|
let merges_path = generate_config.merges_resource.get_local_path()?;
|
||||||
{
|
let weights_path = generate_config.model_resource.get_local_path()?;
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100ModelResources::M2M100_418M,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.model_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_resource = if generate_config.config_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100ConfigResources::M2M100_418M,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.config_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let vocab_resource = if generate_config.vocab_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100VocabResources::M2M100_418M,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.vocab_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let merges_resource = if generate_config.merges_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100MergesResources::M2M100_418M,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.merges_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
|
||||||
let weights_path = model_resource.get_local_path()?;
|
|
||||||
let device = generate_config.device;
|
let device = generate_config.device;
|
||||||
|
|
||||||
generate_config.validate();
|
generate_config.validate();
|
||||||
@ -889,7 +844,7 @@ mod test {
|
|||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
resources::{RemoteResource, Resource},
|
resources::{RemoteResource, ResourceProvider},
|
||||||
Config,
|
Config,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -898,7 +853,7 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
#[ignore] // compilation is enough, no need to run
|
#[ignore] // compilation is enough, no need to run
|
||||||
fn mbart_model_send() {
|
fn mbart_model_send() {
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
M2M100ConfigResources::M2M100_418M,
|
M2M100ConfigResources::M2M100_418M,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path().expect("");
|
let config_path = config_resource.get_local_path().expect("");
|
||||||
|
@ -20,22 +20,22 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::m2m_100::{M2M100Config, M2M100Model};
|
//! use rust_bert::m2m_100::{M2M100Config, M2M100Model};
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::M2M100Tokenizer;
|
//! use rust_tokenizers::tokenizer::M2M100Tokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let merges_resource = Resource::Local(LocalResource {
|
//! let merges_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/spiece.model"),
|
//! local_path: PathBuf::from("path/to/spiece.model"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let merges_path = merges_resource.get_local_path()?;
|
//! let merges_path = merges_resource.get_local_path()?;
|
||||||
|
@ -21,22 +21,22 @@
|
|||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::bart::{BartConfig, BartModel};
|
//! use rust_bert::bart::{BartConfig, BartModel};
|
||||||
//! use rust_bert::marian::MarianForConditionalGeneration;
|
//! use rust_bert::marian::MarianForConditionalGeneration;
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::MarianTokenizer;
|
//! use rust_tokenizers::tokenizer::MarianTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.json"),
|
//! local_path: PathBuf::from("path/to/vocab.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let sentence_piece_resource = Resource::Local(LocalResource {
|
//! let sentence_piece_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/spiece.model"),
|
//! local_path: PathBuf::from("path/to/spiece.model"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let spiece_path = sentence_piece_resource.get_local_path()?;
|
//! let spiece_path = sentence_piece_resource.get_local_path()?;
|
||||||
|
@ -12,7 +12,6 @@
|
|||||||
|
|
||||||
use crate::bart::BartModelOutput;
|
use crate::bart::BartModelOutput;
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
|
|
||||||
use crate::mbart::decoder::MBartDecoder;
|
use crate::mbart::decoder::MBartDecoder;
|
||||||
use crate::mbart::encoder::MBartEncoder;
|
use crate::mbart::encoder::MBartEncoder;
|
||||||
use crate::mbart::LayerState;
|
use crate::mbart::LayerState;
|
||||||
@ -24,7 +23,6 @@ use crate::pipelines::generation_utils::{
|
|||||||
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
|
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
|
||||||
};
|
};
|
||||||
use crate::pipelines::translation::Language;
|
use crate::pipelines::translation::Language;
|
||||||
use crate::resources::{RemoteResource, Resource};
|
|
||||||
use crate::{Activation, Config, RustBertError};
|
use crate::{Activation, Config, RustBertError};
|
||||||
use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::{MBart50Vocab, Vocab};
|
use rust_tokenizers::vocab::{MBart50Vocab, Vocab};
|
||||||
@ -839,40 +837,9 @@ impl MBartGenerator {
|
|||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn new(generate_config: GenerateConfig) -> Result<MBartGenerator, RustBertError> {
|
pub fn new(generate_config: GenerateConfig) -> Result<MBartGenerator, RustBertError> {
|
||||||
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
|
let config_path = generate_config.config_resource.get_local_path()?;
|
||||||
let model_resource = if generate_config.model_resource
|
let vocab_path = generate_config.vocab_resource.get_local_path()?;
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
|
let weights_path = generate_config.model_resource.get_local_path()?;
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
MBartModelResources::MBART50_MANY_TO_MANY,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.model_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_resource = if generate_config.config_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.config_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let vocab_resource = if generate_config.vocab_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.vocab_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
|
||||||
let weights_path = model_resource.get_local_path()?;
|
|
||||||
let device = generate_config.device;
|
let device = generate_config.device;
|
||||||
|
|
||||||
generate_config.validate();
|
generate_config.validate();
|
||||||
@ -1099,7 +1066,7 @@ mod test {
|
|||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
resources::{RemoteResource, Resource},
|
resources::{RemoteResource, ResourceProvider},
|
||||||
Config,
|
Config,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1108,7 +1075,7 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
#[ignore] // compilation is enough, no need to run
|
#[ignore] // compilation is enough, no need to run
|
||||||
fn mbart_model_send() {
|
fn mbart_model_send() {
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
MBartConfigResources::MBART50_MANY_TO_MANY,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path().expect("");
|
let config_path = config_resource.get_local_path().expect("");
|
||||||
|
@ -19,19 +19,19 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::mbart::{MBartConfig, MBartModel};
|
//! use rust_bert::mbart::{MBartConfig, MBartModel};
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::MBart50Tokenizer;
|
//! use rust_tokenizers::tokenizer::MBart50Tokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let weights_path = weights_resource.get_local_path()?;
|
//! let weights_path = weights_resource.get_local_path()?;
|
||||||
|
@ -24,19 +24,19 @@
|
|||||||
//! MobileBertConfig, MobileBertConfigResources, MobileBertForMaskedLM,
|
//! MobileBertConfig, MobileBertConfigResources, MobileBertForMaskedLM,
|
||||||
//! MobileBertModelResources, MobileBertVocabResources,
|
//! MobileBertModelResources, MobileBertVocabResources,
|
||||||
//! };
|
//! };
|
||||||
//! use rust_bert::resources::{RemoteResource, Resource};
|
//! use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::BertTokenizer;
|
//! use rust_tokenizers::tokenizer::BertTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let config_resource = RemoteResource::from_pretrained(
|
||||||
//! MobileBertConfigResources::MOBILEBERT_UNCASED,
|
//! MobileBertConfigResources::MOBILEBERT_UNCASED,
|
||||||
//! ));
|
//! );
|
||||||
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let vocab_resource = RemoteResource::from_pretrained(
|
||||||
//! MobileBertVocabResources::MOBILEBERT_UNCASED,
|
//! MobileBertVocabResources::MOBILEBERT_UNCASED,
|
||||||
//! ));
|
//! );
|
||||||
//! let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let weights_resource = RemoteResource::from_pretrained(
|
||||||
//! MobileBertModelResources::MOBILEBERT_UNCASED,
|
//! MobileBertModelResources::MOBILEBERT_UNCASED,
|
||||||
//! ));
|
//! );
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let weights_path = weights_resource.get_local_path()?;
|
//! let weights_path = weights_resource.get_local_path()?;
|
||||||
|
@ -18,22 +18,22 @@
|
|||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::gpt2::Gpt2Config;
|
//! use rust_bert::gpt2::Gpt2Config;
|
||||||
//! use rust_bert::openai_gpt::OpenAiGptModel;
|
//! use rust_bert::openai_gpt::OpenAiGptModel;
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::OpenAiGptTokenizer;
|
//! use rust_tokenizers::tokenizer::OpenAiGptTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let merges_resource = Resource::Local(LocalResource {
|
//! let merges_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let merges_path = merges_resource.get_local_path()?;
|
//! let merges_path = merges_resource.get_local_path()?;
|
||||||
|
@ -15,10 +15,7 @@
|
|||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use crate::common::embeddings::process_ids_embeddings_pair;
|
use crate::common::embeddings::process_ids_embeddings_pair;
|
||||||
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
use crate::gpt2::Gpt2Config;
|
||||||
use crate::gpt2::{
|
|
||||||
Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
|
||||||
};
|
|
||||||
use crate::openai_gpt::transformer::Block;
|
use crate::openai_gpt::transformer::Block;
|
||||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||||
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
||||||
@ -471,51 +468,10 @@ impl OpenAIGenerator {
|
|||||||
pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> {
|
pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> {
|
||||||
generate_config.validate();
|
generate_config.validate();
|
||||||
|
|
||||||
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
|
let config_path = generate_config.config_resource.get_local_path()?;
|
||||||
let model_resource = if generate_config.model_resource
|
let vocab_path = generate_config.vocab_resource.get_local_path()?;
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
|
let merges_path = generate_config.merges_resource.get_local_path()?;
|
||||||
{
|
let weights_path = generate_config.model_resource.get_local_path()?;
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
OpenAiGptModelResources::GPT,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.model_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_resource = if generate_config.config_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
OpenAiGptConfigResources::GPT,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.config_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let vocab_resource = if generate_config.vocab_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
OpenAiGptVocabResources::GPT,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.vocab_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let merges_resource = if generate_config.merges_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
OpenAiGptMergesResources::GPT,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.merges_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
|
||||||
let weights_path = model_resource.get_local_path()?;
|
|
||||||
let device = generate_config.device;
|
let device = generate_config.device;
|
||||||
|
|
||||||
let mut var_store = nn::VarStore::new(device);
|
let mut var_store = nn::VarStore::new(device);
|
||||||
|
@ -19,19 +19,19 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::pegasus::{PegasusConfig, PegasusModel};
|
//! use rust_bert::pegasus::{PegasusConfig, PegasusModel};
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::PegasusTokenizer;
|
//! use rust_tokenizers::tokenizer::PegasusTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/spiece.model"),
|
//! local_path: PathBuf::from("path/to/spiece.model"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let weights_path = weights_resource.get_local_path()?;
|
//! let weights_path = weights_resource.get_local_path()?;
|
||||||
|
@ -12,8 +12,6 @@
|
|||||||
|
|
||||||
use crate::bart::BartModelOutput;
|
use crate::bart::BartModelOutput;
|
||||||
use crate::common::kind::get_negative_infinity;
|
use crate::common::kind::get_negative_infinity;
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
|
||||||
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
|
|
||||||
use crate::mbart::MBartConfig;
|
use crate::mbart::MBartConfig;
|
||||||
use crate::pegasus::decoder::PegasusDecoder;
|
use crate::pegasus::decoder::PegasusDecoder;
|
||||||
use crate::pegasus::encoder::PegasusEncoder;
|
use crate::pegasus::encoder::PegasusEncoder;
|
||||||
@ -601,40 +599,9 @@ impl PegasusConditionalGenerator {
|
|||||||
pub fn new(
|
pub fn new(
|
||||||
generate_config: GenerateConfig,
|
generate_config: GenerateConfig,
|
||||||
) -> Result<PegasusConditionalGenerator, RustBertError> {
|
) -> Result<PegasusConditionalGenerator, RustBertError> {
|
||||||
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
|
let config_path = generate_config.config_resource.get_local_path()?;
|
||||||
let model_resource = if generate_config.model_resource
|
let vocab_path = generate_config.vocab_resource.get_local_path()?;
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
|
let weights_path = generate_config.model_resource.get_local_path()?;
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
PegasusModelResources::CNN_DAILYMAIL,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.model_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_resource = if generate_config.config_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
PegasusConfigResources::CNN_DAILYMAIL,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.config_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let vocab_resource = if generate_config.vocab_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
PegasusVocabResources::CNN_DAILYMAIL,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.vocab_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
|
||||||
let weights_path = model_resource.get_local_path()?;
|
|
||||||
let device = generate_config.device;
|
let device = generate_config.device;
|
||||||
|
|
||||||
generate_config.validate();
|
generate_config.validate();
|
||||||
|
@ -45,17 +45,21 @@
|
|||||||
//! The authors of this repository are not responsible for any generation
|
//! The authors of this repository are not responsible for any generation
|
||||||
//! from the 3rd party utilization of the pretrained system.
|
//! from the 3rd party utilization of the pretrained system.
|
||||||
use crate::common::error::RustBertError;
|
use crate::common::error::RustBertError;
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
use crate::gpt2::GPT2Generator;
|
||||||
use crate::gpt2::{
|
|
||||||
GPT2Generator, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
|
||||||
};
|
|
||||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||||
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
||||||
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||||
|
use crate::resources::ResourceProvider;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tch::{Device, Kind, Tensor};
|
use tch::{Device, Kind, Tensor};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
use crate::{
|
||||||
|
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
|
||||||
|
resources::RemoteResource,
|
||||||
|
};
|
||||||
|
|
||||||
/// # Configuration for multi-turn classification
|
/// # Configuration for multi-turn classification
|
||||||
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
||||||
/// different set of default parameters and sets the device to place the model on.
|
/// different set of default parameters and sets the device to place the model on.
|
||||||
@ -63,13 +67,13 @@ pub struct ConversationConfig {
|
|||||||
/// Model type
|
/// Model type
|
||||||
pub model_type: ModelType,
|
pub model_type: ModelType,
|
||||||
/// Model weights resource (default: DialoGPT-medium)
|
/// Model weights resource (default: DialoGPT-medium)
|
||||||
pub model_resource: Resource,
|
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Config resource (default: DialoGPT-medium)
|
/// Config resource (default: DialoGPT-medium)
|
||||||
pub config_resource: Resource,
|
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Vocab resource (default: DialoGPT-medium)
|
/// Vocab resource (default: DialoGPT-medium)
|
||||||
pub vocab_resource: Resource,
|
pub vocab_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Merges resource (default: DialoGPT-medium)
|
/// Merges resource (default: DialoGPT-medium)
|
||||||
pub merges_resource: Resource,
|
pub merges_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Minimum sequence length (default: 0)
|
/// Minimum sequence length (default: 0)
|
||||||
pub min_length: i64,
|
pub min_length: i64,
|
||||||
/// Maximum sequence length (default: 20)
|
/// Maximum sequence length (default: 20)
|
||||||
@ -104,20 +108,21 @@ pub struct ConversationConfig {
|
|||||||
pub device: Device,
|
pub device: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
impl Default for ConversationConfig {
|
impl Default for ConversationConfig {
|
||||||
fn default() -> ConversationConfig {
|
fn default() -> ConversationConfig {
|
||||||
ConversationConfig {
|
ConversationConfig {
|
||||||
model_type: ModelType::GPT2,
|
model_type: ModelType::GPT2,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
Gpt2ModelResources::DIALOGPT_MEDIUM,
|
Gpt2ModelResources::DIALOGPT_MEDIUM,
|
||||||
)),
|
)),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
Gpt2ConfigResources::DIALOGPT_MEDIUM,
|
Gpt2ConfigResources::DIALOGPT_MEDIUM,
|
||||||
)),
|
)),
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
vocab_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
Gpt2VocabResources::DIALOGPT_MEDIUM,
|
Gpt2VocabResources::DIALOGPT_MEDIUM,
|
||||||
)),
|
)),
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
merges_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
Gpt2MergesResources::DIALOGPT_MEDIUM,
|
Gpt2MergesResources::DIALOGPT_MEDIUM,
|
||||||
)),
|
)),
|
||||||
min_length: 0,
|
min_length: 0,
|
||||||
|
@ -73,10 +73,7 @@ use tch::{no_grad, Device, Tensor};
|
|||||||
|
|
||||||
use crate::bart::LayerState as BartLayerState;
|
use crate::bart::LayerState as BartLayerState;
|
||||||
use crate::common::error::RustBertError;
|
use crate::common::error::RustBertError;
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
use crate::common::resources::ResourceProvider;
|
||||||
use crate::gpt2::{
|
|
||||||
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
|
||||||
};
|
|
||||||
use crate::gpt_neo::LayerState as GPTNeoLayerState;
|
use crate::gpt_neo::LayerState as GPTNeoLayerState;
|
||||||
use crate::pipelines::generation_utils::private_generation_utils::{
|
use crate::pipelines::generation_utils::private_generation_utils::{
|
||||||
InternalGenerateOptions, PrivateLanguageGenerator,
|
InternalGenerateOptions, PrivateLanguageGenerator,
|
||||||
@ -89,18 +86,24 @@ use crate::xlnet::LayerState as XLNetLayerState;
|
|||||||
use self::ordered_float::OrderedFloat;
|
use self::ordered_float::OrderedFloat;
|
||||||
use crate::pipelines::common::TokenizerOption;
|
use crate::pipelines::common::TokenizerOption;
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
use crate::{
|
||||||
|
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
|
||||||
|
resources::RemoteResource,
|
||||||
|
};
|
||||||
|
|
||||||
extern crate ordered_float;
|
extern crate ordered_float;
|
||||||
|
|
||||||
/// # Configuration for text generation
|
/// # Configuration for text generation
|
||||||
pub struct GenerateConfig {
|
pub struct GenerateConfig {
|
||||||
/// Model weights resource (default: pretrained GPT2 model)
|
/// Model weights resource (default: pretrained GPT2 model)
|
||||||
pub model_resource: Resource,
|
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Config resource (default: pretrained GPT2 model)
|
/// Config resource (default: pretrained GPT2 model)
|
||||||
pub config_resource: Resource,
|
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Vocab resource (default: pretrained GPT2 model)
|
/// Vocab resource (default: pretrained GPT2 model)
|
||||||
pub vocab_resource: Resource,
|
pub vocab_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Merges resource (default: pretrained GPT2 model)
|
/// Merges resource (default: pretrained GPT2 model)
|
||||||
pub merges_resource: Resource,
|
pub merges_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Minimum sequence length (default: 0)
|
/// Minimum sequence length (default: 0)
|
||||||
pub min_length: i64,
|
pub min_length: i64,
|
||||||
/// Maximum sequence length (default: 20)
|
/// Maximum sequence length (default: 20)
|
||||||
@ -133,21 +136,14 @@ pub struct GenerateConfig {
|
|||||||
pub device: Device,
|
pub device: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
impl Default for GenerateConfig {
|
impl Default for GenerateConfig {
|
||||||
fn default() -> GenerateConfig {
|
fn default() -> GenerateConfig {
|
||||||
GenerateConfig {
|
GenerateConfig {
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
|
||||||
Gpt2ModelResources::GPT2,
|
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
|
||||||
)),
|
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
|
||||||
Gpt2ConfigResources::GPT2,
|
|
||||||
)),
|
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
Gpt2VocabResources::GPT2,
|
|
||||||
)),
|
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
Gpt2MergesResources::GPT2,
|
|
||||||
)),
|
|
||||||
min_length: 0,
|
min_length: 0,
|
||||||
max_length: 20,
|
max_length: 20,
|
||||||
do_sample: true,
|
do_sample: true,
|
||||||
|
@ -78,7 +78,7 @@
|
|||||||
//! use rust_bert::pipelines::common::ModelType;
|
//! use rust_bert::pipelines::common::ModelType;
|
||||||
//! use rust_bert::pipelines::ner::NERModel;
|
//! use rust_bert::pipelines::ner::NERModel;
|
||||||
//! use rust_bert::pipelines::token_classification::TokenClassificationConfig;
|
//! use rust_bert::pipelines::token_classification::TokenClassificationConfig;
|
||||||
//! use rust_bert::resources::{RemoteResource, Resource};
|
//! use rust_bert::resources::RemoteResource;
|
||||||
//! use rust_bert::roberta::{
|
//! use rust_bert::roberta::{
|
||||||
//! RobertaConfigResources, RobertaModelResources, RobertaVocabResources,
|
//! RobertaConfigResources, RobertaModelResources, RobertaVocabResources,
|
||||||
//! };
|
//! };
|
||||||
@ -87,13 +87,13 @@
|
|||||||
//! # fn main() -> anyhow::Result<()> {
|
//! # fn main() -> anyhow::Result<()> {
|
||||||
//! let ner_config = TokenClassificationConfig {
|
//! let ner_config = TokenClassificationConfig {
|
||||||
//! model_type: ModelType::XLMRoberta,
|
//! model_type: ModelType::XLMRoberta,
|
||||||
//! model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
//! model_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
//! RobertaModelResources::XLM_ROBERTA_NER_DE,
|
//! RobertaModelResources::XLM_ROBERTA_NER_DE,
|
||||||
//! )),
|
//! )),
|
||||||
//! config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
//! config_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
//! RobertaConfigResources::XLM_ROBERTA_NER_DE,
|
//! RobertaConfigResources::XLM_ROBERTA_NER_DE,
|
||||||
//! )),
|
//! )),
|
||||||
//! vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
//! vocab_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
//! RobertaVocabResources::XLM_ROBERTA_NER_DE,
|
//! RobertaVocabResources::XLM_ROBERTA_NER_DE,
|
||||||
//! )),
|
//! )),
|
||||||
//! lower_case: false,
|
//! lower_case: false,
|
||||||
|
@ -82,16 +82,20 @@
|
|||||||
//! To run the pipeline for another language, change the POSModel configuration from its default (see the NER pipeline for an illustration).
|
//! To run the pipeline for another language, change the POSModel configuration from its default (see the NER pipeline for an illustration).
|
||||||
|
|
||||||
use crate::common::error::RustBertError;
|
use crate::common::error::RustBertError;
|
||||||
use crate::mobilebert::{
|
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
|
||||||
MobileBertConfigResources, MobileBertModelResources, MobileBertVocabResources,
|
|
||||||
};
|
|
||||||
use crate::pipelines::common::ModelType;
|
|
||||||
use crate::pipelines::token_classification::{
|
|
||||||
LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
|
|
||||||
};
|
|
||||||
use crate::resources::{RemoteResource, Resource};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tch::Device;
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
use {
|
||||||
|
crate::{
|
||||||
|
mobilebert::{
|
||||||
|
MobileBertConfigResources, MobileBertModelResources, MobileBertVocabResources,
|
||||||
|
},
|
||||||
|
pipelines::{common::ModelType, token_classification::LabelAggregationOption},
|
||||||
|
resources::RemoteResource,
|
||||||
|
},
|
||||||
|
tch::Device,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
/// # Part of Speech tag
|
/// # Part of Speech tag
|
||||||
@ -109,19 +113,20 @@ pub struct POSConfig {
|
|||||||
token_classification_config: TokenClassificationConfig,
|
token_classification_config: TokenClassificationConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
impl Default for POSConfig {
|
impl Default for POSConfig {
|
||||||
/// Provides a Part of speech tagging model (English)
|
/// Provides a Part of speech tagging model (English)
|
||||||
fn default() -> POSConfig {
|
fn default() -> POSConfig {
|
||||||
POSConfig {
|
POSConfig {
|
||||||
token_classification_config: TokenClassificationConfig {
|
token_classification_config: TokenClassificationConfig {
|
||||||
model_type: ModelType::MobileBert,
|
model_type: ModelType::MobileBert,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertModelResources::MOBILEBERT_ENGLISH_POS,
|
MobileBertModelResources::MOBILEBERT_ENGLISH_POS,
|
||||||
)),
|
)),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertConfigResources::MOBILEBERT_ENGLISH_POS,
|
MobileBertConfigResources::MOBILEBERT_ENGLISH_POS,
|
||||||
)),
|
)),
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
vocab_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertVocabResources::MOBILEBERT_ENGLISH_POS,
|
MobileBertVocabResources::MOBILEBERT_ENGLISH_POS,
|
||||||
)),
|
)),
|
||||||
merges_resource: None,
|
merges_resource: None,
|
||||||
|
@ -46,17 +46,14 @@
|
|||||||
use crate::albert::AlbertForQuestionAnswering;
|
use crate::albert::AlbertForQuestionAnswering;
|
||||||
use crate::bert::BertForQuestionAnswering;
|
use crate::bert::BertForQuestionAnswering;
|
||||||
use crate::common::error::RustBertError;
|
use crate::common::error::RustBertError;
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
|
||||||
use crate::deberta::DebertaForQuestionAnswering;
|
use crate::deberta::DebertaForQuestionAnswering;
|
||||||
use crate::distilbert::{
|
use crate::distilbert::DistilBertForQuestionAnswering;
|
||||||
DistilBertConfigResources, DistilBertForQuestionAnswering, DistilBertModelResources,
|
|
||||||
DistilBertVocabResources,
|
|
||||||
};
|
|
||||||
use crate::fnet::FNetForQuestionAnswering;
|
use crate::fnet::FNetForQuestionAnswering;
|
||||||
use crate::longformer::LongformerForQuestionAnswering;
|
use crate::longformer::LongformerForQuestionAnswering;
|
||||||
use crate::mobilebert::MobileBertForQuestionAnswering;
|
use crate::mobilebert::MobileBertForQuestionAnswering;
|
||||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||||
use crate::reformer::ReformerForQuestionAnswering;
|
use crate::reformer::ReformerForQuestionAnswering;
|
||||||
|
use crate::resources::ResourceProvider;
|
||||||
use crate::roberta::RobertaForQuestionAnswering;
|
use crate::roberta::RobertaForQuestionAnswering;
|
||||||
use crate::xlnet::XLNetForQuestionAnswering;
|
use crate::xlnet::XLNetForQuestionAnswering;
|
||||||
use rust_tokenizers::{Offset, TokenIdsWithOffsets, TokenizedInput};
|
use rust_tokenizers::{Offset, TokenIdsWithOffsets, TokenizedInput};
|
||||||
@ -70,6 +67,12 @@ use tch::kind::Kind::Float;
|
|||||||
use tch::nn::VarStore;
|
use tch::nn::VarStore;
|
||||||
use tch::{nn, no_grad, Device, Tensor};
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
use crate::{
|
||||||
|
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
|
||||||
|
resources::RemoteResource,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
/// # Input for Question Answering
|
/// # Input for Question Answering
|
||||||
/// Includes a context (containing the answer) and question strings
|
/// Includes a context (containing the answer) and question strings
|
||||||
@ -124,13 +127,13 @@ fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
|
|||||||
/// Contains information regarding the model to load and device to place the model on.
|
/// Contains information regarding the model to load and device to place the model on.
|
||||||
pub struct QuestionAnsweringConfig {
|
pub struct QuestionAnsweringConfig {
|
||||||
/// Model weights resource (default: pretrained DistilBERT model on SQuAD)
|
/// Model weights resource (default: pretrained DistilBERT model on SQuAD)
|
||||||
pub model_resource: Resource,
|
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Config resource (default: pretrained DistilBERT model on SQuAD)
|
/// Config resource (default: pretrained DistilBERT model on SQuAD)
|
||||||
pub config_resource: Resource,
|
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Vocab resource (default: pretrained DistilBERT model on SQuAD)
|
/// Vocab resource (default: pretrained DistilBERT model on SQuAD)
|
||||||
pub vocab_resource: Resource,
|
pub vocab_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Merges resource (default: None)
|
/// Merges resource (default: None)
|
||||||
pub merges_resource: Option<Resource>,
|
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||||
/// Device to place the model on (default: CUDA/GPU when available)
|
/// Device to place the model on (default: CUDA/GPU when available)
|
||||||
pub device: Device,
|
pub device: Device,
|
||||||
/// Model type
|
/// Model type
|
||||||
@ -157,27 +160,30 @@ impl QuestionAnsweringConfig {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
||||||
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot)
|
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
|
||||||
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json)
|
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
|
||||||
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||||
/// * merges_resource - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||||
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||||
pub fn new(
|
pub fn new<R>(
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_resource: Resource,
|
model_resource: R,
|
||||||
config_resource: Resource,
|
config_resource: R,
|
||||||
vocab_resource: Resource,
|
vocab_resource: R,
|
||||||
merges_resource: Option<Resource>,
|
merges_resource: Option<R>,
|
||||||
lower_case: bool,
|
lower_case: bool,
|
||||||
strip_accents: impl Into<Option<bool>>,
|
strip_accents: impl Into<Option<bool>>,
|
||||||
add_prefix_space: impl Into<Option<bool>>,
|
add_prefix_space: impl Into<Option<bool>>,
|
||||||
) -> QuestionAnsweringConfig {
|
) -> QuestionAnsweringConfig
|
||||||
|
where
|
||||||
|
R: ResourceProvider + Send + 'static,
|
||||||
|
{
|
||||||
QuestionAnsweringConfig {
|
QuestionAnsweringConfig {
|
||||||
model_type,
|
model_type,
|
||||||
model_resource,
|
model_resource: Box::new(model_resource),
|
||||||
config_resource,
|
config_resource: Box::new(config_resource),
|
||||||
vocab_resource,
|
vocab_resource: Box::new(vocab_resource),
|
||||||
merges_resource,
|
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||||
lower_case,
|
lower_case,
|
||||||
strip_accents: strip_accents.into(),
|
strip_accents: strip_accents.into(),
|
||||||
add_prefix_space: add_prefix_space.into(),
|
add_prefix_space: add_prefix_space.into(),
|
||||||
@ -194,21 +200,21 @@ impl QuestionAnsweringConfig {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
||||||
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot)
|
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
|
||||||
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json)
|
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
|
||||||
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||||
/// * merges_resource - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||||
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||||
/// * max_seq_length - Optional maximum sequence token length to limit memory footprint. If the context is too long, it will be processed with sliding windows. Defaults to 384.
|
/// * max_seq_length - Optional maximum sequence token length to limit memory footprint. If the context is too long, it will be processed with sliding windows. Defaults to 384.
|
||||||
/// * max_query_length - Optional maximum question token length. Defaults to 64.
|
/// * max_query_length - Optional maximum question token length. Defaults to 64.
|
||||||
/// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128.
|
/// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128.
|
||||||
/// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15.
|
/// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15.
|
||||||
pub fn custom_new(
|
pub fn custom_new<R>(
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_resource: Resource,
|
model_resource: R,
|
||||||
config_resource: Resource,
|
config_resource: R,
|
||||||
vocab_resource: Resource,
|
vocab_resource: R,
|
||||||
merges_resource: Option<Resource>,
|
merges_resource: Option<R>,
|
||||||
lower_case: bool,
|
lower_case: bool,
|
||||||
strip_accents: impl Into<Option<bool>>,
|
strip_accents: impl Into<Option<bool>>,
|
||||||
add_prefix_space: impl Into<Option<bool>>,
|
add_prefix_space: impl Into<Option<bool>>,
|
||||||
@ -216,13 +222,16 @@ impl QuestionAnsweringConfig {
|
|||||||
doc_stride: impl Into<Option<usize>>,
|
doc_stride: impl Into<Option<usize>>,
|
||||||
max_query_length: impl Into<Option<usize>>,
|
max_query_length: impl Into<Option<usize>>,
|
||||||
max_answer_length: impl Into<Option<usize>>,
|
max_answer_length: impl Into<Option<usize>>,
|
||||||
) -> QuestionAnsweringConfig {
|
) -> QuestionAnsweringConfig
|
||||||
|
where
|
||||||
|
R: ResourceProvider + Send + 'static,
|
||||||
|
{
|
||||||
QuestionAnsweringConfig {
|
QuestionAnsweringConfig {
|
||||||
model_type,
|
model_type,
|
||||||
model_resource,
|
model_resource: Box::new(model_resource),
|
||||||
config_resource,
|
config_resource: Box::new(config_resource),
|
||||||
vocab_resource,
|
vocab_resource: Box::new(vocab_resource),
|
||||||
merges_resource,
|
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||||
lower_case,
|
lower_case,
|
||||||
strip_accents: strip_accents.into(),
|
strip_accents: strip_accents.into(),
|
||||||
add_prefix_space: add_prefix_space.into(),
|
add_prefix_space: add_prefix_space.into(),
|
||||||
@ -235,16 +244,17 @@ impl QuestionAnsweringConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
impl Default for QuestionAnsweringConfig {
|
impl Default for QuestionAnsweringConfig {
|
||||||
fn default() -> QuestionAnsweringConfig {
|
fn default() -> QuestionAnsweringConfig {
|
||||||
QuestionAnsweringConfig {
|
QuestionAnsweringConfig {
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
DistilBertModelResources::DISTIL_BERT_SQUAD,
|
DistilBertModelResources::DISTIL_BERT_SQUAD,
|
||||||
)),
|
)),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
||||||
)),
|
)),
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
vocab_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
||||||
)),
|
)),
|
||||||
merges_resource: None,
|
merges_resource: None,
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
|
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
|
||||||
//! use rust_bert::resources::{RemoteResource, Resource};
|
//! use rust_bert::resources::{RemoteResource};
|
||||||
//! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources};
|
//! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources};
|
||||||
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
||||||
//! use rust_bert::pipelines::common::ModelType;
|
//! use rust_bert::pipelines::common::ModelType;
|
||||||
@ -23,9 +23,9 @@
|
|||||||
//!
|
//!
|
||||||
//! //Load a configuration
|
//! //Load a configuration
|
||||||
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
|
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
|
||||||
//! Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2)),
|
//! RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2),
|
||||||
//! Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2)),
|
//! RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
|
||||||
//! Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2)),
|
//! RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
|
||||||
//! None, //merges resource only relevant with ModelType::Roberta
|
//! None, //merges resource only relevant with ModelType::Roberta
|
||||||
//! true, //lowercase
|
//! true, //lowercase
|
||||||
//! None, //strip_accents
|
//! None, //strip_accents
|
||||||
@ -61,17 +61,14 @@ use crate::albert::AlbertForSequenceClassification;
|
|||||||
use crate::bart::BartForSequenceClassification;
|
use crate::bart::BartForSequenceClassification;
|
||||||
use crate::bert::BertForSequenceClassification;
|
use crate::bert::BertForSequenceClassification;
|
||||||
use crate::common::error::RustBertError;
|
use crate::common::error::RustBertError;
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
|
||||||
use crate::deberta::DebertaForSequenceClassification;
|
use crate::deberta::DebertaForSequenceClassification;
|
||||||
use crate::distilbert::{
|
use crate::distilbert::DistilBertModelClassifier;
|
||||||
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
|
|
||||||
DistilBertVocabResources,
|
|
||||||
};
|
|
||||||
use crate::fnet::FNetForSequenceClassification;
|
use crate::fnet::FNetForSequenceClassification;
|
||||||
use crate::longformer::LongformerForSequenceClassification;
|
use crate::longformer::LongformerForSequenceClassification;
|
||||||
use crate::mobilebert::MobileBertForSequenceClassification;
|
use crate::mobilebert::MobileBertForSequenceClassification;
|
||||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||||
use crate::reformer::ReformerForSequenceClassification;
|
use crate::reformer::ReformerForSequenceClassification;
|
||||||
|
use crate::resources::ResourceProvider;
|
||||||
use crate::roberta::RobertaForSequenceClassification;
|
use crate::roberta::RobertaForSequenceClassification;
|
||||||
use crate::xlnet::XLNetForSequenceClassification;
|
use crate::xlnet::XLNetForSequenceClassification;
|
||||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||||
@ -82,6 +79,12 @@ use std::collections::HashMap;
|
|||||||
use tch::nn::VarStore;
|
use tch::nn::VarStore;
|
||||||
use tch::{nn, no_grad, Device, Kind, Tensor};
|
use tch::{nn, no_grad, Device, Kind, Tensor};
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
use crate::{
|
||||||
|
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
|
||||||
|
resources::RemoteResource,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
/// # Label generated by a `SequenceClassificationModel`
|
/// # Label generated by a `SequenceClassificationModel`
|
||||||
pub struct Label {
|
pub struct Label {
|
||||||
@ -102,13 +105,13 @@ pub struct SequenceClassificationConfig {
|
|||||||
/// Model type
|
/// Model type
|
||||||
pub model_type: ModelType,
|
pub model_type: ModelType,
|
||||||
/// Model weights resource (default: pretrained BERT model on CoNLL)
|
/// Model weights resource (default: pretrained BERT model on CoNLL)
|
||||||
pub model_resource: Resource,
|
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Config resource (default: pretrained BERT model on CoNLL)
|
/// Config resource (default: pretrained BERT model on CoNLL)
|
||||||
pub config_resource: Resource,
|
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Vocab resource (default: pretrained BERT model on CoNLL)
|
/// Vocab resource (default: pretrained BERT model on CoNLL)
|
||||||
pub vocab_resource: Resource,
|
pub vocab_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Merges resource (default: None)
|
/// Merges resource (default: None)
|
||||||
pub merges_resource: Option<Resource>,
|
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||||
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
|
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
|
||||||
pub lower_case: bool,
|
pub lower_case: bool,
|
||||||
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
|
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
|
||||||
@ -125,27 +128,30 @@ impl SequenceClassificationConfig {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
||||||
/// * model - The `Resource` pointing to the model to load (e.g. model.ot)
|
/// * model - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
|
||||||
/// * config - The `Resource' pointing to the model configuration to load (e.g. config.json)
|
/// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
|
||||||
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||||
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
/// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||||
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||||
pub fn new(
|
pub fn new<R>(
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_resource: Resource,
|
model_resource: R,
|
||||||
config_resource: Resource,
|
config_resource: R,
|
||||||
vocab_resource: Resource,
|
vocab_resource: R,
|
||||||
merges_resource: Option<Resource>,
|
merges_resource: Option<R>,
|
||||||
lower_case: bool,
|
lower_case: bool,
|
||||||
strip_accents: impl Into<Option<bool>>,
|
strip_accents: impl Into<Option<bool>>,
|
||||||
add_prefix_space: impl Into<Option<bool>>,
|
add_prefix_space: impl Into<Option<bool>>,
|
||||||
) -> SequenceClassificationConfig {
|
) -> SequenceClassificationConfig
|
||||||
|
where
|
||||||
|
R: ResourceProvider + Send + 'static,
|
||||||
|
{
|
||||||
SequenceClassificationConfig {
|
SequenceClassificationConfig {
|
||||||
model_type,
|
model_type,
|
||||||
model_resource,
|
model_resource: Box::new(model_resource),
|
||||||
config_resource,
|
config_resource: Box::new(config_resource),
|
||||||
vocab_resource,
|
vocab_resource: Box::new(vocab_resource),
|
||||||
merges_resource,
|
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||||
lower_case,
|
lower_case,
|
||||||
strip_accents: strip_accents.into(),
|
strip_accents: strip_accents.into(),
|
||||||
add_prefix_space: add_prefix_space.into(),
|
add_prefix_space: add_prefix_space.into(),
|
||||||
@ -154,26 +160,20 @@ impl SequenceClassificationConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
impl Default for SequenceClassificationConfig {
|
impl Default for SequenceClassificationConfig {
|
||||||
/// Provides a defaultSST-2 sentiment analysis model (English)
|
/// Provides a defaultSST-2 sentiment analysis model (English)
|
||||||
fn default() -> SequenceClassificationConfig {
|
fn default() -> SequenceClassificationConfig {
|
||||||
SequenceClassificationConfig {
|
SequenceClassificationConfig::new(
|
||||||
model_type: ModelType::DistilBert,
|
ModelType::DistilBert,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2),
|
||||||
DistilBertModelResources::DISTIL_BERT_SST2,
|
RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
|
||||||
)),
|
RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
None,
|
||||||
DistilBertConfigResources::DISTIL_BERT_SST2,
|
true,
|
||||||
)),
|
None,
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
None,
|
||||||
DistilBertVocabResources::DISTIL_BERT_SST2,
|
)
|
||||||
)),
|
|
||||||
merges_resource: None,
|
|
||||||
lower_case: true,
|
|
||||||
strip_accents: None,
|
|
||||||
add_prefix_space: None,
|
|
||||||
device: Device::cuda_if_available(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,17 +64,21 @@
|
|||||||
|
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
use crate::bart::{
|
use crate::bart::BartGenerator;
|
||||||
BartConfigResources, BartGenerator, BartMergesResources, BartModelResources, BartVocabResources,
|
|
||||||
};
|
|
||||||
use crate::common::error::RustBertError;
|
use crate::common::error::RustBertError;
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
|
||||||
use crate::pegasus::PegasusConditionalGenerator;
|
use crate::pegasus::PegasusConditionalGenerator;
|
||||||
use crate::pipelines::common::ModelType;
|
use crate::pipelines::common::ModelType;
|
||||||
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||||
use crate::prophetnet::ProphetNetConditionalGenerator;
|
use crate::prophetnet::ProphetNetConditionalGenerator;
|
||||||
|
use crate::resources::ResourceProvider;
|
||||||
use crate::t5::T5Generator;
|
use crate::t5::T5Generator;
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
use crate::{
|
||||||
|
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
|
||||||
|
resources::RemoteResource,
|
||||||
|
};
|
||||||
|
|
||||||
/// # Configuration for text summarization
|
/// # Configuration for text summarization
|
||||||
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
||||||
/// different set of default parameters and sets the device to place the model on.
|
/// different set of default parameters and sets the device to place the model on.
|
||||||
@ -82,13 +86,13 @@ pub struct SummarizationConfig {
|
|||||||
/// Model type
|
/// Model type
|
||||||
pub model_type: ModelType,
|
pub model_type: ModelType,
|
||||||
/// Model weights resource (default: pretrained BART model on CNN-DM)
|
/// Model weights resource (default: pretrained BART model on CNN-DM)
|
||||||
pub model_resource: Resource,
|
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Config resource (default: pretrained BART model on CNN-DM)
|
/// Config resource (default: pretrained BART model on CNN-DM)
|
||||||
pub config_resource: Resource,
|
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Vocab resource (default: pretrained BART model on CNN-DM)
|
/// Vocab resource (default: pretrained BART model on CNN-DM)
|
||||||
pub vocab_resource: Resource,
|
pub vocab_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Merges resource (default: pretrained BART model on CNN-DM)
|
/// Merges resource (default: pretrained BART model on CNN-DM)
|
||||||
pub merges_resource: Resource,
|
pub merges_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Minimum sequence length (default: 0)
|
/// Minimum sequence length (default: 0)
|
||||||
pub min_length: i64,
|
pub min_length: i64,
|
||||||
/// Maximum sequence length (default: 20)
|
/// Maximum sequence length (default: 20)
|
||||||
@ -127,45 +131,26 @@ impl SummarizationConfig {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
||||||
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot)
|
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
|
||||||
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json)
|
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
|
||||||
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||||
/// * merges_resource - The `Resource` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
|
/// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
|
||||||
pub fn new(
|
pub fn new<R>(
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_resource: Resource,
|
model_resource: R,
|
||||||
config_resource: Resource,
|
config_resource: R,
|
||||||
vocab_resource: Resource,
|
vocab_resource: R,
|
||||||
merges_resource: Resource,
|
merges_resource: R,
|
||||||
) -> SummarizationConfig {
|
) -> SummarizationConfig
|
||||||
|
where
|
||||||
|
R: ResourceProvider + Send + 'static,
|
||||||
|
{
|
||||||
SummarizationConfig {
|
SummarizationConfig {
|
||||||
model_type,
|
model_type,
|
||||||
model_resource,
|
model_resource: Box::new(model_resource),
|
||||||
config_resource,
|
config_resource: Box::new(config_resource),
|
||||||
vocab_resource,
|
vocab_resource: Box::new(vocab_resource),
|
||||||
merges_resource,
|
merges_resource: Box::new(merges_resource),
|
||||||
device: Device::cuda_if_available(),
|
|
||||||
..Default::default()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for SummarizationConfig {
|
|
||||||
fn default() -> SummarizationConfig {
|
|
||||||
SummarizationConfig {
|
|
||||||
model_type: ModelType::Bart,
|
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
BartModelResources::BART_CNN,
|
|
||||||
)),
|
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
BartConfigResources::BART_CNN,
|
|
||||||
)),
|
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
BartVocabResources::BART_CNN,
|
|
||||||
)),
|
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
BartMergesResources::BART_CNN,
|
|
||||||
)),
|
|
||||||
min_length: 56,
|
min_length: 56,
|
||||||
max_length: 142,
|
max_length: 142,
|
||||||
do_sample: false,
|
do_sample: false,
|
||||||
@ -185,6 +170,19 @@ impl Default for SummarizationConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
impl Default for SummarizationConfig {
|
||||||
|
fn default() -> SummarizationConfig {
|
||||||
|
SummarizationConfig::new(
|
||||||
|
ModelType::Bart,
|
||||||
|
RemoteResource::from_pretrained(BartModelResources::BART_CNN),
|
||||||
|
RemoteResource::from_pretrained(BartConfigResources::BART_CNN),
|
||||||
|
RemoteResource::from_pretrained(BartVocabResources::BART_CNN),
|
||||||
|
RemoteResource::from_pretrained(BartMergesResources::BART_CNN),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<SummarizationConfig> for GenerateConfig {
|
impl From<SummarizationConfig> for GenerateConfig {
|
||||||
fn from(config: SummarizationConfig) -> GenerateConfig {
|
fn from(config: SummarizationConfig) -> GenerateConfig {
|
||||||
GenerateConfig {
|
GenerateConfig {
|
||||||
|
@ -34,19 +34,22 @@
|
|||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
use crate::common::error::RustBertError;
|
use crate::common::error::RustBertError;
|
||||||
use crate::common::resources::RemoteResource;
|
use crate::gpt2::GPT2Generator;
|
||||||
use crate::gpt2::{
|
|
||||||
GPT2Generator, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
|
||||||
};
|
|
||||||
use crate::gpt_neo::GptNeoGenerator;
|
use crate::gpt_neo::GptNeoGenerator;
|
||||||
use crate::openai_gpt::OpenAIGenerator;
|
use crate::openai_gpt::OpenAIGenerator;
|
||||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||||
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
||||||
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
|
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
|
||||||
use crate::reformer::ReformerGenerator;
|
use crate::reformer::ReformerGenerator;
|
||||||
use crate::resources::Resource;
|
use crate::resources::ResourceProvider;
|
||||||
use crate::xlnet::XLNetGenerator;
|
use crate::xlnet::XLNetGenerator;
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
use crate::{
|
||||||
|
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
|
||||||
|
resources::RemoteResource,
|
||||||
|
};
|
||||||
|
|
||||||
/// # Configuration for text generation
|
/// # Configuration for text generation
|
||||||
/// Contains information regarding the model to load, mirrors the GenerateConfig, with a
|
/// Contains information regarding the model to load, mirrors the GenerateConfig, with a
|
||||||
/// different set of default parameters and sets the device to place the model on.
|
/// different set of default parameters and sets the device to place the model on.
|
||||||
@ -54,13 +57,13 @@ pub struct TextGenerationConfig {
|
|||||||
/// Model type
|
/// Model type
|
||||||
pub model_type: ModelType,
|
pub model_type: ModelType,
|
||||||
/// Model weights resource (default: pretrained BART model on CNN-DM)
|
/// Model weights resource (default: pretrained BART model on CNN-DM)
|
||||||
pub model_resource: Resource,
|
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Config resource (default: pretrained BART model on CNN-DM)
|
/// Config resource (default: pretrained BART model on CNN-DM)
|
||||||
pub config_resource: Resource,
|
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Vocab resource (default: pretrained BART model on CNN-DM)
|
/// Vocab resource (default: pretrained BART model on CNN-DM)
|
||||||
pub vocab_resource: Resource,
|
pub vocab_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Merges resource (default: pretrained BART model on CNN-DM)
|
/// Merges resource (default: pretrained BART model on CNN-DM)
|
||||||
pub merges_resource: Resource,
|
pub merges_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Minimum sequence length (default: 0)
|
/// Minimum sequence length (default: 0)
|
||||||
pub min_length: i64,
|
pub min_length: i64,
|
||||||
/// Maximum sequence length (default: 20)
|
/// Maximum sequence length (default: 20)
|
||||||
@ -99,45 +102,26 @@ impl TextGenerationConfig {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
||||||
/// * model_resource - The `Resource` pointing to the model to load (e.g. model.ot)
|
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
|
||||||
/// * config_resource - The `Resource' pointing to the model configuration to load (e.g. config.json)
|
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
|
||||||
/// * vocab_resource - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||||
/// * merges_resource - The `Resource` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
|
/// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
|
||||||
pub fn new(
|
pub fn new<R>(
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_resource: Resource,
|
model_resource: R,
|
||||||
config_resource: Resource,
|
config_resource: R,
|
||||||
vocab_resource: Resource,
|
vocab_resource: R,
|
||||||
merges_resource: Resource,
|
merges_resource: R,
|
||||||
) -> TextGenerationConfig {
|
) -> TextGenerationConfig
|
||||||
|
where
|
||||||
|
R: ResourceProvider + Send + 'static,
|
||||||
|
{
|
||||||
TextGenerationConfig {
|
TextGenerationConfig {
|
||||||
model_type,
|
model_type,
|
||||||
model_resource,
|
model_resource: Box::new(model_resource),
|
||||||
config_resource,
|
config_resource: Box::new(config_resource),
|
||||||
vocab_resource,
|
vocab_resource: Box::new(vocab_resource),
|
||||||
merges_resource,
|
merges_resource: Box::new(merges_resource),
|
||||||
device: Device::cuda_if_available(),
|
|
||||||
..Default::default()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for TextGenerationConfig {
|
|
||||||
fn default() -> TextGenerationConfig {
|
|
||||||
TextGenerationConfig {
|
|
||||||
model_type: ModelType::GPT2,
|
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
Gpt2ModelResources::GPT2_MEDIUM,
|
|
||||||
)),
|
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
Gpt2ConfigResources::GPT2_MEDIUM,
|
|
||||||
)),
|
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
Gpt2VocabResources::GPT2_MEDIUM,
|
|
||||||
)),
|
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
Gpt2MergesResources::GPT2_MEDIUM,
|
|
||||||
)),
|
|
||||||
min_length: 0,
|
min_length: 0,
|
||||||
max_length: 20,
|
max_length: 20,
|
||||||
do_sample: true,
|
do_sample: true,
|
||||||
@ -157,6 +141,19 @@ impl Default for TextGenerationConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
impl Default for TextGenerationConfig {
|
||||||
|
fn default() -> TextGenerationConfig {
|
||||||
|
TextGenerationConfig::new(
|
||||||
|
ModelType::GPT2,
|
||||||
|
RemoteResource::from_pretrained(Gpt2ModelResources::GPT2_MEDIUM),
|
||||||
|
RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM),
|
||||||
|
RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM),
|
||||||
|
RemoteResource::from_pretrained(Gpt2MergesResources::GPT2_MEDIUM),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<TextGenerationConfig> for GenerateConfig {
|
impl From<TextGenerationConfig> for GenerateConfig {
|
||||||
fn from(config: TextGenerationConfig) -> GenerateConfig {
|
fn from(config: TextGenerationConfig) -> GenerateConfig {
|
||||||
GenerateConfig {
|
GenerateConfig {
|
||||||
|
@ -16,17 +16,18 @@
|
|||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! use rust_bert::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
|
//! use rust_bert::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
|
||||||
//! use rust_bert::resources::{Resource,RemoteResource};
|
//! use rust_bert::resources::RemoteResource;
|
||||||
//! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
|
//! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
|
||||||
//! use rust_bert::pipelines::common::ModelType;
|
//! use rust_bert::pipelines::common::ModelType;
|
||||||
//! # fn main() -> anyhow::Result<()> {
|
//! # fn main() -> anyhow::Result<()> {
|
||||||
//!
|
//!
|
||||||
//! //Load a configuration
|
//! //Load a configuration
|
||||||
//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
|
//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
|
||||||
//! let config = TokenClassificationConfig::new(ModelType::Bert,
|
//! let config = TokenClassificationConfig::new(
|
||||||
//! Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
|
//! ModelType::Bert,
|
||||||
//! Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
|
//! RemoteResource::from_pretrained(BertModelResources::BERT_NER),
|
||||||
//! Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
|
//! RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
|
||||||
|
//! RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
|
||||||
//! None, //merges resource only relevant with ModelType::Roberta
|
//! None, //merges resource only relevant with ModelType::Roberta
|
||||||
//! false, //lowercase
|
//! false, //lowercase
|
||||||
//! None, //strip_accents
|
//! None, //strip_accents
|
||||||
@ -111,11 +112,8 @@
|
|||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use crate::albert::AlbertForTokenClassification;
|
use crate::albert::AlbertForTokenClassification;
|
||||||
use crate::bert::{
|
use crate::bert::BertForTokenClassification;
|
||||||
BertConfigResources, BertForTokenClassification, BertModelResources, BertVocabResources,
|
|
||||||
};
|
|
||||||
use crate::common::error::RustBertError;
|
use crate::common::error::RustBertError;
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
|
||||||
use crate::deberta::DebertaForTokenClassification;
|
use crate::deberta::DebertaForTokenClassification;
|
||||||
use crate::distilbert::DistilBertForTokenClassification;
|
use crate::distilbert::DistilBertForTokenClassification;
|
||||||
use crate::electra::ElectraForTokenClassification;
|
use crate::electra::ElectraForTokenClassification;
|
||||||
@ -123,6 +121,7 @@ use crate::fnet::FNetForTokenClassification;
|
|||||||
use crate::longformer::LongformerForTokenClassification;
|
use crate::longformer::LongformerForTokenClassification;
|
||||||
use crate::mobilebert::MobileBertForTokenClassification;
|
use crate::mobilebert::MobileBertForTokenClassification;
|
||||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||||
|
use crate::resources::ResourceProvider;
|
||||||
use crate::roberta::RobertaForTokenClassification;
|
use crate::roberta::RobertaForTokenClassification;
|
||||||
use crate::xlnet::XLNetForTokenClassification;
|
use crate::xlnet::XLNetForTokenClassification;
|
||||||
use rust_tokenizers::tokenizer::Tokenizer;
|
use rust_tokenizers::tokenizer::Tokenizer;
|
||||||
@ -137,6 +136,12 @@ use std::collections::HashMap;
|
|||||||
use tch::nn::VarStore;
|
use tch::nn::VarStore;
|
||||||
use tch::{nn, no_grad, Device, Kind, Tensor};
|
use tch::{nn, no_grad, Device, Kind, Tensor};
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
use crate::{
|
||||||
|
bert::{BertConfigResources, BertModelResources, BertVocabResources},
|
||||||
|
resources::RemoteResource,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
/// # Token generated by a `TokenClassificationModel`
|
/// # Token generated by a `TokenClassificationModel`
|
||||||
pub struct Token {
|
pub struct Token {
|
||||||
@ -215,13 +220,13 @@ pub struct TokenClassificationConfig {
|
|||||||
/// Model type
|
/// Model type
|
||||||
pub model_type: ModelType,
|
pub model_type: ModelType,
|
||||||
/// Model weights resource (default: pretrained BERT model on CoNLL)
|
/// Model weights resource (default: pretrained BERT model on CoNLL)
|
||||||
pub model_resource: Resource,
|
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Config resource (default: pretrained BERT model on CoNLL)
|
/// Config resource (default: pretrained BERT model on CoNLL)
|
||||||
pub config_resource: Resource,
|
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Vocab resource (default: pretrained BERT model on CoNLL)
|
/// Vocab resource (default: pretrained BERT model on CoNLL)
|
||||||
pub vocab_resource: Resource,
|
pub vocab_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Merges resource (default: pretrained BERT model on CoNLL)
|
/// Merges resource (default: pretrained BERT model on CoNLL)
|
||||||
pub merges_resource: Option<Resource>,
|
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||||
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
|
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
|
||||||
pub lower_case: bool,
|
pub lower_case: bool,
|
||||||
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
|
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
|
||||||
@ -242,28 +247,31 @@ impl TokenClassificationConfig {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
||||||
/// * model - The `Resource` pointing to the model to load (e.g. model.ot)
|
/// * model - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
|
||||||
/// * config - The `Resource' pointing to the model configuration to load (e.g. config.json)
|
/// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
|
||||||
/// * vocab - The `Resource' pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json)
|
/// * vocab - The `ResourceProvider` pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||||
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta.
|
/// * vocab - An optional `ResourceProvider` pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||||
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||||
pub fn new(
|
pub fn new<R>(
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_resource: Resource,
|
model_resource: R,
|
||||||
config_resource: Resource,
|
config_resource: R,
|
||||||
vocab_resource: Resource,
|
vocab_resource: R,
|
||||||
merges_resource: Option<Resource>,
|
merges_resource: Option<R>,
|
||||||
lower_case: bool,
|
lower_case: bool,
|
||||||
strip_accents: impl Into<Option<bool>>,
|
strip_accents: impl Into<Option<bool>>,
|
||||||
add_prefix_space: impl Into<Option<bool>>,
|
add_prefix_space: impl Into<Option<bool>>,
|
||||||
label_aggregation_function: LabelAggregationOption,
|
label_aggregation_function: LabelAggregationOption,
|
||||||
) -> TokenClassificationConfig {
|
) -> TokenClassificationConfig
|
||||||
|
where
|
||||||
|
R: ResourceProvider + Send + 'static,
|
||||||
|
{
|
||||||
TokenClassificationConfig {
|
TokenClassificationConfig {
|
||||||
model_type,
|
model_type,
|
||||||
model_resource,
|
model_resource: Box::new(model_resource),
|
||||||
config_resource,
|
config_resource: Box::new(config_resource),
|
||||||
vocab_resource,
|
vocab_resource: Box::new(vocab_resource),
|
||||||
merges_resource,
|
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||||
lower_case,
|
lower_case,
|
||||||
strip_accents: strip_accents.into(),
|
strip_accents: strip_accents.into(),
|
||||||
add_prefix_space: add_prefix_space.into(),
|
add_prefix_space: add_prefix_space.into(),
|
||||||
@ -274,28 +282,21 @@ impl TokenClassificationConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
impl Default for TokenClassificationConfig {
|
impl Default for TokenClassificationConfig {
|
||||||
/// Provides a default CoNLL-2003 NER model (English)
|
/// Provides a default CoNLL-2003 NER model (English)
|
||||||
fn default() -> TokenClassificationConfig {
|
fn default() -> TokenClassificationConfig {
|
||||||
TokenClassificationConfig {
|
TokenClassificationConfig::new(
|
||||||
model_type: ModelType::Bert,
|
ModelType::Bert,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(BertModelResources::BERT_NER),
|
||||||
BertModelResources::BERT_NER,
|
RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
|
||||||
)),
|
RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
None,
|
||||||
BertConfigResources::BERT_NER,
|
false,
|
||||||
)),
|
None,
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
None,
|
||||||
BertVocabResources::BERT_NER,
|
LabelAggregationOption::First,
|
||||||
)),
|
)
|
||||||
merges_resource: None,
|
|
||||||
lower_case: false,
|
|
||||||
strip_accents: None,
|
|
||||||
add_prefix_space: None,
|
|
||||||
device: Device::cuda_if_available(),
|
|
||||||
label_aggregation_function: LabelAggregationOption::First,
|
|
||||||
batch_size: 64,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,22 +21,14 @@
|
|||||||
//! };
|
//! };
|
||||||
//! use rust_bert::pipelines::common::ModelType;
|
//! use rust_bert::pipelines::common::ModelType;
|
||||||
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
//! use rust_bert::resources::{RemoteResource, Resource};
|
//! use rust_bert::resources::RemoteResource;
|
||||||
//! use tch::Device;
|
//! use tch::Device;
|
||||||
//!
|
//!
|
||||||
//! fn main() -> anyhow::Result<()> {
|
//! fn main() -> anyhow::Result<()> {
|
||||||
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
|
||||||
//! M2M100ModelResources::M2M100_418M,
|
//! let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
|
||||||
//! ));
|
//! let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
|
||||||
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
|
||||||
//! M2M100ConfigResources::M2M100_418M,
|
|
||||||
//! ));
|
|
||||||
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
//! M2M100VocabResources::M2M100_418M,
|
|
||||||
//! ));
|
|
||||||
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
//! M2M100MergesResources::M2M100_418M,
|
|
||||||
//! ));
|
|
||||||
//!
|
//!
|
||||||
//! let source_languages = M2M100SourceLanguages::M2M100_418M;
|
//! let source_languages = M2M100SourceLanguages::M2M100_418M;
|
||||||
//! let target_languages = M2M100TargetLanguages::M2M100_418M;
|
//! let target_languages = M2M100TargetLanguages::M2M100_418M;
|
||||||
|
@ -1,31 +1,14 @@
|
|||||||
use crate::m2m_100::{
|
|
||||||
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
|
|
||||||
M2M100TargetLanguages, M2M100VocabResources,
|
|
||||||
};
|
|
||||||
use crate::marian::{
|
|
||||||
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
|
|
||||||
MarianTargetLanguages, MarianVocabResources,
|
|
||||||
};
|
|
||||||
use crate::mbart::{
|
|
||||||
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
|
|
||||||
MBartVocabResources,
|
|
||||||
};
|
|
||||||
use crate::pipelines::common::ModelType;
|
use crate::pipelines::common::ModelType;
|
||||||
use crate::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
use crate::pipelines::translation::Language;
|
||||||
use crate::resources::{RemoteResource, Resource};
|
|
||||||
use crate::RustBertError;
|
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
struct TranslationResources {
|
#[cfg(feature = "remote")]
|
||||||
model_type: ModelType,
|
use crate::{
|
||||||
model_resource: Resource,
|
pipelines::translation::{TranslationConfig, TranslationModel},
|
||||||
config_resource: Resource,
|
resources::ResourceProvider,
|
||||||
vocab_resource: Resource,
|
RustBertError,
|
||||||
merges_resource: Resource,
|
};
|
||||||
source_languages: Vec<Language>,
|
|
||||||
target_languages: Vec<Language>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, PartialEq)]
|
#[derive(Clone, Copy, PartialEq)]
|
||||||
enum ModelSize {
|
enum ModelSize {
|
||||||
@ -86,21 +69,6 @@ pub struct TranslationModelBuilder {
|
|||||||
model_size: Option<ModelSize>,
|
model_size: Option<ModelSize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! get_marian_resources {
|
|
||||||
($name:ident) => {
|
|
||||||
(
|
|
||||||
(
|
|
||||||
MarianModelResources::$name,
|
|
||||||
MarianConfigResources::$name,
|
|
||||||
MarianVocabResources::$name,
|
|
||||||
MarianSpmResources::$name,
|
|
||||||
),
|
|
||||||
MarianSourceLanguages::$name.iter().cloned().collect(),
|
|
||||||
MarianTargetLanguages::$name.iter().cloned().collect(),
|
|
||||||
)
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for TranslationModelBuilder {
|
impl Default for TranslationModelBuilder {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
TranslationModelBuilder::new()
|
TranslationModelBuilder::new()
|
||||||
@ -335,29 +303,162 @@ impl TranslationModelBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_default_model(
|
/// Creates the translation model based on the specifications provided
|
||||||
&self,
|
///
|
||||||
source_languages: Option<&Vec<Language>>,
|
/// # Returns
|
||||||
target_languages: Option<&Vec<Language>>,
|
/// * `TranslationModel` Generated translation model
|
||||||
) -> Result<TranslationResources, RustBertError> {
|
///
|
||||||
Ok(
|
/// # Example
|
||||||
match self.get_marian_model(source_languages, target_languages) {
|
///
|
||||||
Ok(marian_resources) => marian_resources,
|
/// ```no_run
|
||||||
Err(_) => match self.model_size {
|
/// use rust_bert::pipelines::translation::Language;
|
||||||
|
/// use rust_bert::pipelines::translation::TranslationModelBuilder;
|
||||||
|
/// fn main() -> anyhow::Result<()> {
|
||||||
|
/// let model = TranslationModelBuilder::new()
|
||||||
|
/// .with_target_languages([
|
||||||
|
/// Language::Japanese,
|
||||||
|
/// Language::Korean,
|
||||||
|
/// Language::ChineseMandarin,
|
||||||
|
/// ])
|
||||||
|
/// .create_model();
|
||||||
|
/// Ok(())
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
pub fn create_model(&self) -> Result<TranslationModel, RustBertError> {
|
||||||
|
let device = self.device.unwrap_or_else(Device::cuda_if_available);
|
||||||
|
|
||||||
|
let translation_resources = match (
|
||||||
|
&self.model_type,
|
||||||
|
&self.source_languages,
|
||||||
|
&self.target_languages,
|
||||||
|
) {
|
||||||
|
(Some(ModelType::M2M100), source_languages, target_languages) => {
|
||||||
|
match self.model_size {
|
||||||
Some(value) if value == ModelSize::XLarge => {
|
Some(value) if value == ModelSize::XLarge => {
|
||||||
self.get_m2m100_xlarge_resources(source_languages, target_languages)?
|
model_fetchers::get_m2m100_xlarge_resources(
|
||||||
|
source_languages.as_ref(),
|
||||||
|
target_languages.as_ref(),
|
||||||
|
)?
|
||||||
}
|
}
|
||||||
_ => self.get_m2m100_large_resources(source_languages, target_languages)?,
|
_ => model_fetchers::get_m2m100_large_resources(
|
||||||
},
|
source_languages.as_ref(),
|
||||||
},
|
target_languages.as_ref(),
|
||||||
)
|
)?,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(Some(ModelType::MBart), source_languages, target_languages) => {
|
||||||
|
model_fetchers::get_mbart50_resources(
|
||||||
|
source_languages.as_ref(),
|
||||||
|
target_languages.as_ref(),
|
||||||
|
)?
|
||||||
|
}
|
||||||
|
(Some(ModelType::Marian), source_languages, target_languages) => {
|
||||||
|
model_fetchers::get_marian_model(
|
||||||
|
source_languages.as_ref(),
|
||||||
|
target_languages.as_ref(),
|
||||||
|
)?
|
||||||
|
}
|
||||||
|
(None, source_languages, target_languages) => model_fetchers::get_default_model(
|
||||||
|
&self.model_size,
|
||||||
|
source_languages.as_ref(),
|
||||||
|
target_languages.as_ref(),
|
||||||
|
)?,
|
||||||
|
(_, None, None) | (_, _, None) | (_, None, _) => {
|
||||||
|
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
|
"Source and target languages must be specified for {:?}",
|
||||||
|
self.model_type.unwrap()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
(Some(model_type), _, _) => {
|
||||||
|
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||||
|
"Automated translation model builder not implemented for {:?}",
|
||||||
|
model_type
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let translation_config = TranslationConfig::new(
|
||||||
|
translation_resources.model_type,
|
||||||
|
translation_resources.model_resource,
|
||||||
|
translation_resources.config_resource,
|
||||||
|
translation_resources.vocab_resource,
|
||||||
|
translation_resources.merges_resource,
|
||||||
|
translation_resources.source_languages,
|
||||||
|
translation_resources.target_languages,
|
||||||
|
device,
|
||||||
|
);
|
||||||
|
TranslationModel::new(translation_config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
mod model_fetchers {
|
||||||
|
use super::*;
|
||||||
|
use crate::{
|
||||||
|
m2m_100::{
|
||||||
|
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources,
|
||||||
|
M2M100SourceLanguages, M2M100TargetLanguages, M2M100VocabResources,
|
||||||
|
},
|
||||||
|
marian::{
|
||||||
|
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
|
||||||
|
MarianTargetLanguages, MarianVocabResources,
|
||||||
|
},
|
||||||
|
mbart::{
|
||||||
|
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
|
||||||
|
MBartVocabResources,
|
||||||
|
},
|
||||||
|
resources::RemoteResource,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub(super) struct TranslationResources<R>
|
||||||
|
where
|
||||||
|
R: ResourceProvider + Send + 'static,
|
||||||
|
{
|
||||||
|
pub(super) model_type: ModelType,
|
||||||
|
pub(super) model_resource: R,
|
||||||
|
pub(super) config_resource: R,
|
||||||
|
pub(super) vocab_resource: R,
|
||||||
|
pub(super) merges_resource: R,
|
||||||
|
pub(super) source_languages: Vec<Language>,
|
||||||
|
pub(super) target_languages: Vec<Language>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_marian_model(
|
macro_rules! get_marian_resources {
|
||||||
&self,
|
($name:ident) => {
|
||||||
|
(
|
||||||
|
(
|
||||||
|
MarianModelResources::$name,
|
||||||
|
MarianConfigResources::$name,
|
||||||
|
MarianVocabResources::$name,
|
||||||
|
MarianSpmResources::$name,
|
||||||
|
),
|
||||||
|
MarianSourceLanguages::$name.iter().cloned().collect(),
|
||||||
|
MarianTargetLanguages::$name.iter().cloned().collect(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn get_default_model(
|
||||||
|
model_size: &Option<ModelSize>,
|
||||||
source_languages: Option<&Vec<Language>>,
|
source_languages: Option<&Vec<Language>>,
|
||||||
target_languages: Option<&Vec<Language>>,
|
target_languages: Option<&Vec<Language>>,
|
||||||
) -> Result<TranslationResources, RustBertError> {
|
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
|
||||||
|
Ok(match get_marian_model(source_languages, target_languages) {
|
||||||
|
Ok(marian_resources) => marian_resources,
|
||||||
|
Err(_) => match model_size {
|
||||||
|
Some(value) if value == &ModelSize::XLarge => {
|
||||||
|
get_m2m100_xlarge_resources(source_languages, target_languages)?
|
||||||
|
}
|
||||||
|
_ => get_m2m100_large_resources(source_languages, target_languages)?,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn get_marian_model(
|
||||||
|
source_languages: Option<&Vec<Language>>,
|
||||||
|
target_languages: Option<&Vec<Language>>,
|
||||||
|
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
|
||||||
let (resources, source_languages, target_languages) =
|
let (resources, source_languages, target_languages) =
|
||||||
if let (Some(source_languages), Some(target_languages)) =
|
if let (Some(source_languages), Some(target_languages)) =
|
||||||
(source_languages, target_languages)
|
(source_languages, target_languages)
|
||||||
@ -446,20 +547,19 @@ impl TranslationModelBuilder {
|
|||||||
|
|
||||||
Ok(TranslationResources {
|
Ok(TranslationResources {
|
||||||
model_type: ModelType::Marian,
|
model_type: ModelType::Marian,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(resources.0)),
|
model_resource: RemoteResource::from_pretrained(resources.0),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(resources.1)),
|
config_resource: RemoteResource::from_pretrained(resources.1),
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(resources.2)),
|
vocab_resource: RemoteResource::from_pretrained(resources.2),
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(resources.3)),
|
merges_resource: RemoteResource::from_pretrained(resources.3),
|
||||||
source_languages,
|
source_languages,
|
||||||
target_languages,
|
target_languages,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_mbart50_resources(
|
pub(super) fn get_mbart50_resources(
|
||||||
&self,
|
|
||||||
source_languages: Option<&Vec<Language>>,
|
source_languages: Option<&Vec<Language>>,
|
||||||
target_languages: Option<&Vec<Language>>,
|
target_languages: Option<&Vec<Language>>,
|
||||||
) -> Result<TranslationResources, RustBertError> {
|
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
|
||||||
if let Some(source_languages) = source_languages {
|
if let Some(source_languages) = source_languages {
|
||||||
if !source_languages
|
if !source_languages
|
||||||
.iter()
|
.iter()
|
||||||
@ -488,28 +588,27 @@ impl TranslationModelBuilder {
|
|||||||
|
|
||||||
Ok(TranslationResources {
|
Ok(TranslationResources {
|
||||||
model_type: ModelType::MBart,
|
model_type: ModelType::MBart,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
model_resource: RemoteResource::from_pretrained(
|
||||||
MBartModelResources::MBART50_MANY_TO_MANY,
|
MBartModelResources::MBART50_MANY_TO_MANY,
|
||||||
)),
|
),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
config_resource: RemoteResource::from_pretrained(
|
||||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
MBartConfigResources::MBART50_MANY_TO_MANY,
|
||||||
)),
|
),
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
vocab_resource: RemoteResource::from_pretrained(
|
||||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||||
)),
|
),
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
merges_resource: RemoteResource::from_pretrained(
|
||||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||||
)),
|
),
|
||||||
source_languages: MBartSourceLanguages::MBART50_MANY_TO_MANY.to_vec(),
|
source_languages: MBartSourceLanguages::MBART50_MANY_TO_MANY.to_vec(),
|
||||||
target_languages: MBartTargetLanguages::MBART50_MANY_TO_MANY.to_vec(),
|
target_languages: MBartTargetLanguages::MBART50_MANY_TO_MANY.to_vec(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_m2m100_large_resources(
|
pub(super) fn get_m2m100_large_resources(
|
||||||
&self,
|
|
||||||
source_languages: Option<&Vec<Language>>,
|
source_languages: Option<&Vec<Language>>,
|
||||||
target_languages: Option<&Vec<Language>>,
|
target_languages: Option<&Vec<Language>>,
|
||||||
) -> Result<TranslationResources, RustBertError> {
|
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
|
||||||
if let Some(source_languages) = source_languages {
|
if let Some(source_languages) = source_languages {
|
||||||
if !source_languages
|
if !source_languages
|
||||||
.iter()
|
.iter()
|
||||||
@ -538,28 +637,19 @@ impl TranslationModelBuilder {
|
|||||||
|
|
||||||
Ok(TranslationResources {
|
Ok(TranslationResources {
|
||||||
model_type: ModelType::M2M100,
|
model_type: ModelType::M2M100,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
model_resource: RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M),
|
||||||
M2M100ModelResources::M2M100_418M,
|
config_resource: RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M),
|
||||||
)),
|
vocab_resource: RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
merges_resource: RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M),
|
||||||
M2M100ConfigResources::M2M100_418M,
|
|
||||||
)),
|
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100VocabResources::M2M100_418M,
|
|
||||||
)),
|
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100MergesResources::M2M100_418M,
|
|
||||||
)),
|
|
||||||
source_languages: M2M100SourceLanguages::M2M100_418M.to_vec(),
|
source_languages: M2M100SourceLanguages::M2M100_418M.to_vec(),
|
||||||
target_languages: M2M100TargetLanguages::M2M100_418M.to_vec(),
|
target_languages: M2M100TargetLanguages::M2M100_418M.to_vec(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_m2m100_xlarge_resources(
|
pub(super) fn get_m2m100_xlarge_resources(
|
||||||
&self,
|
|
||||||
source_languages: Option<&Vec<Language>>,
|
source_languages: Option<&Vec<Language>>,
|
||||||
target_languages: Option<&Vec<Language>>,
|
target_languages: Option<&Vec<Language>>,
|
||||||
) -> Result<TranslationResources, RustBertError> {
|
) -> Result<TranslationResources<RemoteResource>, RustBertError> {
|
||||||
if let Some(source_languages) = source_languages {
|
if let Some(source_languages) = source_languages {
|
||||||
if !source_languages
|
if !source_languages
|
||||||
.iter()
|
.iter()
|
||||||
@ -588,97 +678,12 @@ impl TranslationModelBuilder {
|
|||||||
|
|
||||||
Ok(TranslationResources {
|
Ok(TranslationResources {
|
||||||
model_type: ModelType::M2M100,
|
model_type: ModelType::M2M100,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
model_resource: RemoteResource::from_pretrained(M2M100ModelResources::M2M100_1_2B),
|
||||||
M2M100ModelResources::M2M100_1_2B,
|
config_resource: RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_1_2B),
|
||||||
)),
|
vocab_resource: RemoteResource::from_pretrained(M2M100VocabResources::M2M100_1_2B),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
merges_resource: RemoteResource::from_pretrained(M2M100MergesResources::M2M100_1_2B),
|
||||||
M2M100ConfigResources::M2M100_1_2B,
|
|
||||||
)),
|
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100VocabResources::M2M100_1_2B,
|
|
||||||
)),
|
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100MergesResources::M2M100_1_2B,
|
|
||||||
)),
|
|
||||||
source_languages: M2M100SourceLanguages::M2M100_1_2B.to_vec(),
|
source_languages: M2M100SourceLanguages::M2M100_1_2B.to_vec(),
|
||||||
target_languages: M2M100TargetLanguages::M2M100_1_2B.to_vec(),
|
target_languages: M2M100TargetLanguages::M2M100_1_2B.to_vec(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates the translation model based on the specifications provided
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// * `TranslationModel` Generated translation model
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```no_run
|
|
||||||
/// use rust_bert::pipelines::translation::Language;
|
|
||||||
/// use rust_bert::pipelines::translation::TranslationModelBuilder;
|
|
||||||
/// fn main() -> anyhow::Result<()> {
|
|
||||||
/// let model = TranslationModelBuilder::new()
|
|
||||||
/// .with_target_languages([
|
|
||||||
/// Language::Japanese,
|
|
||||||
/// Language::Korean,
|
|
||||||
/// Language::ChineseMandarin,
|
|
||||||
/// ])
|
|
||||||
/// .create_model();
|
|
||||||
/// Ok(())
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub fn create_model(&self) -> Result<TranslationModel, RustBertError> {
|
|
||||||
let device = self.device.unwrap_or_else(Device::cuda_if_available);
|
|
||||||
|
|
||||||
let translation_resources = match (
|
|
||||||
&self.model_type,
|
|
||||||
&self.source_languages,
|
|
||||||
&self.target_languages,
|
|
||||||
) {
|
|
||||||
(Some(ModelType::M2M100), source_languages, target_languages) => {
|
|
||||||
match self.model_size {
|
|
||||||
Some(value) if value == ModelSize::XLarge => self.get_m2m100_xlarge_resources(
|
|
||||||
source_languages.as_ref(),
|
|
||||||
target_languages.as_ref(),
|
|
||||||
)?,
|
|
||||||
_ => self.get_m2m100_large_resources(
|
|
||||||
source_languages.as_ref(),
|
|
||||||
target_languages.as_ref(),
|
|
||||||
)?,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(Some(ModelType::MBart), source_languages, target_languages) => {
|
|
||||||
self.get_mbart50_resources(source_languages.as_ref(), target_languages.as_ref())?
|
|
||||||
}
|
|
||||||
(Some(ModelType::Marian), source_languages, target_languages) => {
|
|
||||||
self.get_marian_model(source_languages.as_ref(), target_languages.as_ref())?
|
|
||||||
}
|
|
||||||
(None, source_languages, target_languages) => {
|
|
||||||
self.get_default_model(source_languages.as_ref(), target_languages.as_ref())?
|
|
||||||
}
|
|
||||||
(_, None, None) | (_, _, None) | (_, None, _) => {
|
|
||||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
|
||||||
"Source and target languages must be specified for {:?}",
|
|
||||||
self.model_type.unwrap()
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
(Some(model_type), _, _) => {
|
|
||||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
|
||||||
"Automated translation model builder not implemented for {:?}",
|
|
||||||
model_type
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let translation_config = TranslationConfig::new(
|
|
||||||
translation_resources.model_type,
|
|
||||||
translation_resources.model_resource,
|
|
||||||
translation_resources.config_resource,
|
|
||||||
translation_resources.vocab_resource,
|
|
||||||
translation_resources.merges_resource,
|
|
||||||
translation_resources.source_languages,
|
|
||||||
translation_resources.target_languages,
|
|
||||||
device,
|
|
||||||
);
|
|
||||||
TranslationModel::new(translation_config)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -14,13 +14,13 @@
|
|||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
use crate::common::error::RustBertError;
|
use crate::common::error::RustBertError;
|
||||||
use crate::common::resources::Resource;
|
|
||||||
use crate::m2m_100::M2M100Generator;
|
use crate::m2m_100::M2M100Generator;
|
||||||
use crate::marian::MarianGenerator;
|
use crate::marian::MarianGenerator;
|
||||||
use crate::mbart::MBartGenerator;
|
use crate::mbart::MBartGenerator;
|
||||||
use crate::pipelines::common::ModelType;
|
use crate::pipelines::common::ModelType;
|
||||||
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
||||||
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
|
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
|
||||||
|
use crate::resources::ResourceProvider;
|
||||||
use crate::t5::T5Generator;
|
use crate::t5::T5Generator;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
@ -374,13 +374,13 @@ pub struct TranslationConfig {
|
|||||||
/// Model type used for translation
|
/// Model type used for translation
|
||||||
pub model_type: ModelType,
|
pub model_type: ModelType,
|
||||||
/// Model weights resource
|
/// Model weights resource
|
||||||
pub model_resource: Resource,
|
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Config resource
|
/// Config resource
|
||||||
pub config_resource: Resource,
|
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Vocab resource
|
/// Vocab resource
|
||||||
pub vocab_resource: Resource,
|
pub vocab_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Merges resource
|
/// Merges resource
|
||||||
pub merges_resource: Resource,
|
pub merges_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Supported source languages
|
/// Supported source languages
|
||||||
pub source_languages: HashSet<Language>,
|
pub source_languages: HashSet<Language>,
|
||||||
/// Supported target languages
|
/// Supported target languages
|
||||||
@ -435,18 +435,18 @@ impl TranslationConfig {
|
|||||||
/// };
|
/// };
|
||||||
/// use rust_bert::pipelines::common::ModelType;
|
/// use rust_bert::pipelines::common::ModelType;
|
||||||
/// use rust_bert::pipelines::translation::TranslationConfig;
|
/// use rust_bert::pipelines::translation::TranslationConfig;
|
||||||
/// use rust_bert::resources::{RemoteResource, Resource};
|
/// use rust_bert::resources::RemoteResource;
|
||||||
/// use tch::Device;
|
/// use tch::Device;
|
||||||
///
|
///
|
||||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let model_resource = RemoteResource::from_pretrained(
|
||||||
/// MarianModelResources::ROMANCE2ENGLISH,
|
/// MarianModelResources::ROMANCE2ENGLISH,
|
||||||
/// ));
|
/// );
|
||||||
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let config_resource = RemoteResource::from_pretrained(
|
||||||
/// MarianConfigResources::ROMANCE2ENGLISH,
|
/// MarianConfigResources::ROMANCE2ENGLISH,
|
||||||
/// ));
|
/// );
|
||||||
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let vocab_resource = RemoteResource::from_pretrained(
|
||||||
/// MarianVocabResources::ROMANCE2ENGLISH,
|
/// MarianVocabResources::ROMANCE2ENGLISH,
|
||||||
/// ));
|
/// );
|
||||||
///
|
///
|
||||||
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
|
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
|
||||||
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
|
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
|
||||||
@ -464,17 +464,18 @@ impl TranslationConfig {
|
|||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn new<S, T>(
|
pub fn new<R, S, T>(
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_resource: Resource,
|
model_resource: R,
|
||||||
config_resource: Resource,
|
config_resource: R,
|
||||||
vocab_resource: Resource,
|
vocab_resource: R,
|
||||||
merges_resource: Resource,
|
merges_resource: R,
|
||||||
source_languages: S,
|
source_languages: S,
|
||||||
target_languages: T,
|
target_languages: T,
|
||||||
device: impl Into<Option<Device>>,
|
device: impl Into<Option<Device>>,
|
||||||
) -> TranslationConfig
|
) -> TranslationConfig
|
||||||
where
|
where
|
||||||
|
R: ResourceProvider + Send + 'static,
|
||||||
S: AsRef<[Language]>,
|
S: AsRef<[Language]>,
|
||||||
T: AsRef<[Language]>,
|
T: AsRef<[Language]>,
|
||||||
{
|
{
|
||||||
@ -482,10 +483,10 @@ impl TranslationConfig {
|
|||||||
|
|
||||||
TranslationConfig {
|
TranslationConfig {
|
||||||
model_type,
|
model_type,
|
||||||
model_resource,
|
model_resource: Box::new(model_resource),
|
||||||
config_resource,
|
config_resource: Box::new(config_resource),
|
||||||
vocab_resource,
|
vocab_resource: Box::new(vocab_resource),
|
||||||
merges_resource,
|
merges_resource: Box::new(merges_resource),
|
||||||
source_languages: source_languages.as_ref().iter().cloned().collect(),
|
source_languages: source_languages.as_ref().iter().cloned().collect(),
|
||||||
target_languages: target_languages.as_ref().iter().cloned().collect(),
|
target_languages: target_languages.as_ref().iter().cloned().collect(),
|
||||||
device,
|
device,
|
||||||
@ -798,18 +799,18 @@ impl TranslationModel {
|
|||||||
/// };
|
/// };
|
||||||
/// use rust_bert::pipelines::common::ModelType;
|
/// use rust_bert::pipelines::common::ModelType;
|
||||||
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
|
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
|
||||||
/// use rust_bert::resources::{RemoteResource, Resource};
|
/// use rust_bert::resources::RemoteResource;
|
||||||
/// use tch::Device;
|
/// use tch::Device;
|
||||||
///
|
///
|
||||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let model_resource = RemoteResource::from_pretrained(
|
||||||
/// MarianModelResources::ROMANCE2ENGLISH,
|
/// MarianModelResources::ROMANCE2ENGLISH,
|
||||||
/// ));
|
/// );
|
||||||
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let config_resource = RemoteResource::from_pretrained(
|
||||||
/// MarianConfigResources::ROMANCE2ENGLISH,
|
/// MarianConfigResources::ROMANCE2ENGLISH,
|
||||||
/// ));
|
/// );
|
||||||
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let vocab_resource = RemoteResource::from_pretrained(
|
||||||
/// MarianVocabResources::ROMANCE2ENGLISH,
|
/// MarianVocabResources::ROMANCE2ENGLISH,
|
||||||
/// ));
|
/// );
|
||||||
///
|
///
|
||||||
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
|
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
|
||||||
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
|
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
|
||||||
@ -859,21 +860,21 @@ impl TranslationModel {
|
|||||||
/// };
|
/// };
|
||||||
/// use rust_bert::pipelines::common::ModelType;
|
/// use rust_bert::pipelines::common::ModelType;
|
||||||
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
/// use rust_bert::resources::{RemoteResource, Resource};
|
/// use rust_bert::resources::RemoteResource;
|
||||||
/// use tch::Device;
|
/// use tch::Device;
|
||||||
///
|
///
|
||||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let model_resource = RemoteResource::from_pretrained(
|
||||||
/// MarianModelResources::ENGLISH2ROMANCE,
|
/// MarianModelResources::ENGLISH2ROMANCE,
|
||||||
/// ));
|
/// );
|
||||||
/// let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let config_resource = RemoteResource::from_pretrained(
|
||||||
/// MarianConfigResources::ENGLISH2ROMANCE,
|
/// MarianConfigResources::ENGLISH2ROMANCE,
|
||||||
/// ));
|
/// );
|
||||||
/// let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let vocab_resource = RemoteResource::from_pretrained(
|
||||||
/// MarianVocabResources::ENGLISH2ROMANCE,
|
/// MarianVocabResources::ENGLISH2ROMANCE,
|
||||||
/// ));
|
/// );
|
||||||
/// let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let merges_resource = RemoteResource::from_pretrained(
|
||||||
/// MarianSpmResources::ENGLISH2ROMANCE,
|
/// MarianSpmResources::ENGLISH2ROMANCE,
|
||||||
/// ));
|
/// );
|
||||||
/// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
|
/// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
|
||||||
/// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
|
/// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
|
||||||
///
|
///
|
||||||
@ -938,15 +939,10 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
#[ignore] // no need to run, compilation is enough to verify it is Send
|
#[ignore] // no need to run, compilation is enough to verify it is Send
|
||||||
fn test() {
|
fn test() {
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
|
||||||
MarianModelResources::ROMANCE2ENGLISH,
|
let config_resource =
|
||||||
));
|
RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
|
||||||
MarianConfigResources::ROMANCE2ENGLISH,
|
|
||||||
));
|
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
MarianVocabResources::ROMANCE2ENGLISH,
|
|
||||||
));
|
|
||||||
|
|
||||||
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
|
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
|
||||||
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
|
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
|
||||||
|
@ -99,10 +99,7 @@
|
|||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use crate::albert::AlbertForSequenceClassification;
|
use crate::albert::AlbertForSequenceClassification;
|
||||||
use crate::bart::{
|
use crate::bart::BartForSequenceClassification;
|
||||||
BartConfigResources, BartForSequenceClassification, BartMergesResources, BartModelResources,
|
|
||||||
BartVocabResources,
|
|
||||||
};
|
|
||||||
use crate::bert::BertForSequenceClassification;
|
use crate::bert::BertForSequenceClassification;
|
||||||
use crate::deberta::DebertaForSequenceClassification;
|
use crate::deberta::DebertaForSequenceClassification;
|
||||||
use crate::distilbert::DistilBertModelClassifier;
|
use crate::distilbert::DistilBertModelClassifier;
|
||||||
@ -110,7 +107,7 @@ use crate::longformer::LongformerForSequenceClassification;
|
|||||||
use crate::mobilebert::MobileBertForSequenceClassification;
|
use crate::mobilebert::MobileBertForSequenceClassification;
|
||||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||||
use crate::pipelines::sequence_classification::Label;
|
use crate::pipelines::sequence_classification::Label;
|
||||||
use crate::resources::{RemoteResource, Resource};
|
use crate::resources::ResourceProvider;
|
||||||
use crate::roberta::RobertaForSequenceClassification;
|
use crate::roberta::RobertaForSequenceClassification;
|
||||||
use crate::xlnet::XLNetForSequenceClassification;
|
use crate::xlnet::XLNetForSequenceClassification;
|
||||||
use crate::RustBertError;
|
use crate::RustBertError;
|
||||||
@ -122,19 +119,25 @@ use tch::kind::Kind::{Bool, Float};
|
|||||||
use tch::nn::VarStore;
|
use tch::nn::VarStore;
|
||||||
use tch::{nn, no_grad, Device, Tensor};
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
use crate::{
|
||||||
|
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
|
||||||
|
resources::RemoteResource,
|
||||||
|
};
|
||||||
|
|
||||||
/// # Configuration for ZeroShotClassificationModel
|
/// # Configuration for ZeroShotClassificationModel
|
||||||
/// Contains information regarding the model to load and device to place the model on.
|
/// Contains information regarding the model to load and device to place the model on.
|
||||||
pub struct ZeroShotClassificationConfig {
|
pub struct ZeroShotClassificationConfig {
|
||||||
/// Model type
|
/// Model type
|
||||||
pub model_type: ModelType,
|
pub model_type: ModelType,
|
||||||
/// Model weights resource (default: pretrained BERT model on CoNLL)
|
/// Model weights resource (default: pretrained BERT model on CoNLL)
|
||||||
pub model_resource: Resource,
|
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Config resource (default: pretrained BERT model on CoNLL)
|
/// Config resource (default: pretrained BERT model on CoNLL)
|
||||||
pub config_resource: Resource,
|
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Vocab resource (default: pretrained BERT model on CoNLL)
|
/// Vocab resource (default: pretrained BERT model on CoNLL)
|
||||||
pub vocab_resource: Resource,
|
pub vocab_resource: Box<dyn ResourceProvider + Send>,
|
||||||
/// Merges resource (default: None)
|
/// Merges resource (default: None)
|
||||||
pub merges_resource: Option<Resource>,
|
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||||
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
|
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
|
||||||
pub lower_case: bool,
|
pub lower_case: bool,
|
||||||
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
|
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
|
||||||
@ -151,27 +154,30 @@ impl ZeroShotClassificationConfig {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
||||||
/// * model - The `Resource` pointing to the model to load (e.g. model.ot)
|
/// * model - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
|
||||||
/// * config - The `Resource' pointing to the model configuration to load (e.g. config.json)
|
/// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
|
||||||
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||||
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
/// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||||
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||||
pub fn new(
|
pub fn new<R>(
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_resource: Resource,
|
model_resource: R,
|
||||||
config_resource: Resource,
|
config_resource: R,
|
||||||
vocab_resource: Resource,
|
vocab_resource: R,
|
||||||
merges_resource: Option<Resource>,
|
merges_resource: Option<R>,
|
||||||
lower_case: bool,
|
lower_case: bool,
|
||||||
strip_accents: impl Into<Option<bool>>,
|
strip_accents: impl Into<Option<bool>>,
|
||||||
add_prefix_space: impl Into<Option<bool>>,
|
add_prefix_space: impl Into<Option<bool>>,
|
||||||
) -> ZeroShotClassificationConfig {
|
) -> ZeroShotClassificationConfig
|
||||||
|
where
|
||||||
|
R: ResourceProvider + Send + 'static,
|
||||||
|
{
|
||||||
ZeroShotClassificationConfig {
|
ZeroShotClassificationConfig {
|
||||||
model_type,
|
model_type,
|
||||||
model_resource,
|
model_resource: Box::new(model_resource),
|
||||||
config_resource,
|
config_resource: Box::new(config_resource),
|
||||||
vocab_resource,
|
vocab_resource: Box::new(vocab_resource),
|
||||||
merges_resource,
|
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||||
lower_case,
|
lower_case,
|
||||||
strip_accents: strip_accents.into(),
|
strip_accents: strip_accents.into(),
|
||||||
add_prefix_space: add_prefix_space.into(),
|
add_prefix_space: add_prefix_space.into(),
|
||||||
@ -180,21 +186,22 @@ impl ZeroShotClassificationConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
impl Default for ZeroShotClassificationConfig {
|
impl Default for ZeroShotClassificationConfig {
|
||||||
/// Provides a defaultSST-2 sentiment analysis model (English)
|
/// Provides a defaultSST-2 sentiment analysis model (English)
|
||||||
fn default() -> ZeroShotClassificationConfig {
|
fn default() -> ZeroShotClassificationConfig {
|
||||||
ZeroShotClassificationConfig {
|
ZeroShotClassificationConfig {
|
||||||
model_type: ModelType::Bart,
|
model_type: ModelType::Bart,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
BartModelResources::BART_MNLI,
|
BartModelResources::BART_MNLI,
|
||||||
)),
|
)),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
BartConfigResources::BART_MNLI,
|
BartConfigResources::BART_MNLI,
|
||||||
)),
|
)),
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
vocab_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
BartVocabResources::BART_MNLI,
|
BartVocabResources::BART_MNLI,
|
||||||
)),
|
)),
|
||||||
merges_resource: Some(Resource::Remote(RemoteResource::from_pretrained(
|
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
|
||||||
BartMergesResources::BART_MNLI,
|
BartMergesResources::BART_MNLI,
|
||||||
))),
|
))),
|
||||||
lower_case: false,
|
lower_case: false,
|
||||||
|
@ -20,17 +20,17 @@
|
|||||||
//! use rust_bert::prophetnet::{
|
//! use rust_bert::prophetnet::{
|
||||||
//! ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
|
//! ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
|
||||||
//! };
|
//! };
|
||||||
//! use rust_bert::resources::{RemoteResource, Resource};
|
//! use rust_bert::resources::RemoteResource;
|
||||||
//! use tch::Device;
|
//! use tch::Device;
|
||||||
//!
|
//!
|
||||||
//! fn main() -> anyhow::Result<()> {
|
//! fn main() -> anyhow::Result<()> {
|
||||||
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
|
//! ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
|
||||||
//! ));
|
//! ));
|
||||||
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
|
//! ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
|
||||||
//! ));
|
//! ));
|
||||||
//! let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
|
//! ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
|
||||||
//! ));
|
//! ));
|
||||||
//!
|
//!
|
||||||
|
@ -18,8 +18,6 @@ use rust_tokenizers::vocab::{ProphetNetVocab, Vocab};
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tch::{nn, Kind, Tensor};
|
use tch::{nn, Kind, Tensor};
|
||||||
|
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
|
||||||
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
|
|
||||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||||
use crate::pipelines::generation_utils::private_generation_utils::{
|
use crate::pipelines::generation_utils::private_generation_utils::{
|
||||||
PreparedInput, PrivateLanguageGenerator,
|
PreparedInput, PrivateLanguageGenerator,
|
||||||
@ -909,40 +907,9 @@ impl ProphetNetConditionalGenerator {
|
|||||||
pub fn new(
|
pub fn new(
|
||||||
generate_config: GenerateConfig,
|
generate_config: GenerateConfig,
|
||||||
) -> Result<ProphetNetConditionalGenerator, RustBertError> {
|
) -> Result<ProphetNetConditionalGenerator, RustBertError> {
|
||||||
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
|
let config_path = generate_config.config_resource.get_local_path()?;
|
||||||
let model_resource = if generate_config.model_resource
|
let vocab_path = generate_config.vocab_resource.get_local_path()?;
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
|
let weights_path = generate_config.model_resource.get_local_path()?;
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.model_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_resource = if generate_config.config_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.config_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let vocab_resource = if generate_config.vocab_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.vocab_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
|
||||||
let weights_path = model_resource.get_local_path()?;
|
|
||||||
let device = generate_config.device;
|
let device = generate_config.device;
|
||||||
|
|
||||||
generate_config.validate();
|
generate_config.validate();
|
||||||
|
@ -19,19 +19,19 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::reformer::{ReformerConfig, ReformerModel};
|
//! use rust_bert::reformer::{ReformerConfig, ReformerModel};
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::ReformerTokenizer;
|
//! use rust_tokenizers::tokenizer::ReformerTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/weights.ot"),
|
//! local_path: PathBuf::from("path/to/weights.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/spiece.model"),
|
//! local_path: PathBuf::from("path/to/spiece.model"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let weights_path = weights_resource.get_local_path()?;
|
//! let weights_path = weights_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
|
@ -23,8 +23,6 @@ use tch::{nn, Device, Kind, Tensor};
|
|||||||
use crate::common::activations::Activation;
|
use crate::common::activations::Activation;
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
|
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
|
||||||
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
|
|
||||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||||
use crate::pipelines::generation_utils::private_generation_utils::{
|
use crate::pipelines::generation_utils::private_generation_utils::{
|
||||||
PreparedInput, PrivateLanguageGenerator,
|
PreparedInput, PrivateLanguageGenerator,
|
||||||
@ -1019,40 +1017,9 @@ pub struct ReformerGenerator {
|
|||||||
|
|
||||||
impl ReformerGenerator {
|
impl ReformerGenerator {
|
||||||
pub fn new(generate_config: GenerateConfig) -> Result<ReformerGenerator, RustBertError> {
|
pub fn new(generate_config: GenerateConfig) -> Result<ReformerGenerator, RustBertError> {
|
||||||
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
|
let config_path = generate_config.config_resource.get_local_path()?;
|
||||||
let model_resource = if generate_config.model_resource
|
let vocab_path = generate_config.vocab_resource.get_local_path()?;
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
|
let weights_path = generate_config.model_resource.get_local_path()?;
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
ReformerModelResources::CRIME_AND_PUNISHMENT,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.model_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_resource = if generate_config.config_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.config_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let vocab_resource = if generate_config.vocab_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
generate_config.vocab_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
|
||||||
let weights_path = model_resource.get_local_path()?;
|
|
||||||
let device = generate_config.device;
|
let device = generate_config.device;
|
||||||
|
|
||||||
generate_config.validate();
|
generate_config.validate();
|
||||||
|
@ -23,23 +23,23 @@
|
|||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::bert::BertConfig;
|
//! use rust_bert::bert::BertConfig;
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::roberta::RobertaForMaskedLM;
|
//! use rust_bert::roberta::RobertaForMaskedLM;
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::RobertaTokenizer;
|
//! use rust_tokenizers::tokenizer::RobertaTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let vocab_resource = Resource::Local(LocalResource {
|
//! let vocab_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let merges_resource = Resource::Local(LocalResource {
|
//! let merges_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/merges.txt"),
|
//! local_path: PathBuf::from("path/to/merges.txt"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||||
//! let merges_path = merges_resource.get_local_path()?;
|
//! let merges_path = merges_resource.get_local_path()?;
|
||||||
|
@ -19,20 +19,20 @@
|
|||||||
//! #
|
//! #
|
||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//! # use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::resources::{LocalResource, Resource};
|
//! use rust_bert::resources::{LocalResource, ResourceProvider};
|
||||||
//! use rust_bert::t5::{T5Config, T5ForConditionalGeneration};
|
//! use rust_bert::t5::{T5Config, T5ForConditionalGeneration};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_tokenizers::tokenizer::T5Tokenizer;
|
//! use rust_tokenizers::tokenizer::T5Tokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource {
|
//! let config_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/config.json"),
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! });
|
//! };
|
||||||
//! let sentence_piece_resource = Resource::Local(LocalResource {
|
//! let sentence_piece_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/spiece.model"),
|
//! local_path: PathBuf::from("path/to/spiece.model"),
|
||||||
//! });
|
//! };
|
||||||
//! let weights_resource = Resource::Local(LocalResource {
|
//! let weights_resource = LocalResource {
|
||||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
//! });
|
//! };
|
||||||
//! let config_path = config_resource.get_local_path()?;
|
//! let config_path = config_resource.get_local_path()?;
|
||||||
//! let spiece_path = sentence_piece_resource.get_local_path()?;
|
//! let spiece_path = sentence_piece_resource.get_local_path()?;
|
||||||
//! let weights_path = weights_resource.get_local_path()?;
|
//! let weights_path = weights_resource.get_local_path()?;
|
||||||
|
@ -18,8 +18,6 @@ use serde::{Deserialize, Serialize};
|
|||||||
use tch::nn::embedding;
|
use tch::nn::embedding;
|
||||||
use tch::{nn, Tensor};
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
use crate::common::resources::{RemoteResource, Resource};
|
|
||||||
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
|
|
||||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||||
use crate::pipelines::generation_utils::private_generation_utils::{
|
use crate::pipelines::generation_utils::private_generation_utils::{
|
||||||
PreparedInput, PrivateLanguageGenerator,
|
PreparedInput, PrivateLanguageGenerator,
|
||||||
@ -715,34 +713,9 @@ pub struct T5Generator {
|
|||||||
|
|
||||||
impl T5Generator {
|
impl T5Generator {
|
||||||
pub fn new(generate_config: GenerateConfig) -> Result<T5Generator, RustBertError> {
|
pub fn new(generate_config: GenerateConfig) -> Result<T5Generator, RustBertError> {
|
||||||
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
|
let config_path = generate_config.config_resource.get_local_path()?;
|
||||||
let model_resource = if generate_config.model_resource
|
let vocab_path = generate_config.vocab_resource.get_local_path()?;
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
|
let weights_path = generate_config.model_resource.get_local_path()?;
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL))
|
|
||||||
} else {
|
|
||||||
generate_config.model_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_resource = if generate_config.config_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL))
|
|
||||||
} else {
|
|
||||||
generate_config.config_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let vocab_resource = if generate_config.vocab_resource
|
|
||||||
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
|
|
||||||
{
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL))
|
|
||||||
} else {
|
|
||||||
generate_config.vocab_resource.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
|
||||||
let weights_path = model_resource.get_local_path()?;
|
|
||||||
let device = generate_config.device;
|
let device = generate_config.device;
|
||||||
|
|
||||||
generate_config.validate();
|
generate_config.validate();
|
||||||
|
@ -22,18 +22,18 @@
|
|||||||
//! use rust_bert::pipelines::common::ModelType;
|
//! use rust_bert::pipelines::common::ModelType;
|
||||||
//! use rust_bert::pipelines::generation_utils::LanguageGenerator;
|
//! use rust_bert::pipelines::generation_utils::LanguageGenerator;
|
||||||
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||||
//! use rust_bert::resources::{RemoteResource, Resource};
|
//! use rust_bert::resources::RemoteResource;
|
||||||
//! use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
|
//! use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
|
||||||
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! XLNetConfigResources::XLNET_BASE_CASED,
|
//! XLNetConfigResources::XLNET_BASE_CASED,
|
||||||
//! ));
|
//! ));
|
||||||
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! XLNetVocabResources::XLNET_BASE_CASED,
|
//! XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
//! ));
|
//! ));
|
||||||
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! XLNetVocabResources::XLNET_BASE_CASED,
|
//! XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
//! ));
|
//! ));
|
||||||
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
//! let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
//! XLNetModelResources::XLNET_BASE_CASED,
|
//! XLNetModelResources::XLNET_BASE_CASED,
|
||||||
//! ));
|
//! ));
|
||||||
//! let generate_config = TextGenerationConfig {
|
//! let generate_config = TextGenerationConfig {
|
||||||
|
@ -6,7 +6,7 @@ use rust_bert::albert::{
|
|||||||
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
|
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
|
||||||
AlbertModelResources, AlbertVocabResources,
|
AlbertModelResources, AlbertVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{AlbertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{AlbertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::Vocab;
|
use rust_tokenizers::vocab::Vocab;
|
||||||
@ -16,13 +16,13 @@ use tch::{nn, no_grad, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn albert_masked_lm() -> anyhow::Result<()> {
|
fn albert_masked_lm() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertConfigResources::ALBERT_BASE_V2,
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertVocabResources::ALBERT_BASE_V2,
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertModelResources::ALBERT_BASE_V2,
|
AlbertModelResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -87,10 +87,10 @@ fn albert_masked_lm() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn albert_for_sequence_classification() -> anyhow::Result<()> {
|
fn albert_for_sequence_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertConfigResources::ALBERT_BASE_V2,
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertVocabResources::ALBERT_BASE_V2,
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -153,10 +153,10 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn albert_for_multiple_choice() -> anyhow::Result<()> {
|
fn albert_for_multiple_choice() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertConfigResources::ALBERT_BASE_V2,
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertVocabResources::ALBERT_BASE_V2,
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -219,10 +219,10 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn albert_for_token_classification() -> anyhow::Result<()> {
|
fn albert_for_token_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertConfigResources::ALBERT_BASE_V2,
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertVocabResources::ALBERT_BASE_V2,
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -286,10 +286,10 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn albert_for_question_answering() -> anyhow::Result<()> {
|
fn albert_for_question_answering() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertConfigResources::ALBERT_BASE_V2,
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
AlbertVocabResources::ALBERT_BASE_V2,
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
|
@ -6,7 +6,7 @@ use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationMode
|
|||||||
use rust_bert::pipelines::zero_shot_classification::{
|
use rust_bert::pipelines::zero_shot_classification::{
|
||||||
ZeroShotClassificationConfig, ZeroShotClassificationModel,
|
ZeroShotClassificationConfig, ZeroShotClassificationModel,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy};
|
||||||
use tch::{nn, Device, Tensor};
|
use tch::{nn, Device, Tensor};
|
||||||
@ -14,16 +14,16 @@ use tch::{nn, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn bart_lm_model() -> anyhow::Result<()> {
|
fn bart_lm_model() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartConfigResources::DISTILBART_CNN_6_6,
|
BartConfigResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartVocabResources::DISTILBART_CNN_6_6,
|
BartVocabResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartMergesResources::DISTILBART_CNN_6_6,
|
BartMergesResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartModelResources::DISTILBART_CNN_6_6,
|
BartModelResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -77,16 +77,16 @@ fn bart_lm_model() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn bart_summarization_greedy() -> anyhow::Result<()> {
|
fn bart_summarization_greedy() -> anyhow::Result<()> {
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartConfigResources::DISTILBART_CNN_6_6,
|
BartConfigResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartVocabResources::DISTILBART_CNN_6_6,
|
BartVocabResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartMergesResources::DISTILBART_CNN_6_6,
|
BartMergesResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartModelResources::DISTILBART_CNN_6_6,
|
BartModelResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let summarization_config = SummarizationConfig {
|
let summarization_config = SummarizationConfig {
|
||||||
@ -138,16 +138,16 @@ about exoplanets like K2-18b."];
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn bart_summarization_beam_search() -> anyhow::Result<()> {
|
fn bart_summarization_beam_search() -> anyhow::Result<()> {
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartConfigResources::DISTILBART_CNN_6_6,
|
BartConfigResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartVocabResources::DISTILBART_CNN_6_6,
|
BartVocabResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartMergesResources::DISTILBART_CNN_6_6,
|
BartMergesResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
BartModelResources::DISTILBART_CNN_6_6,
|
BartModelResources::DISTILBART_CNN_6_6,
|
||||||
));
|
));
|
||||||
let summarization_config = SummarizationConfig {
|
let summarization_config = SummarizationConfig {
|
||||||
|
@ -11,7 +11,7 @@ use rust_bert::pipelines::ner::NERModel;
|
|||||||
use rust_bert::pipelines::question_answering::{
|
use rust_bert::pipelines::question_answering::{
|
||||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::Vocab;
|
use rust_tokenizers::vocab::Vocab;
|
||||||
@ -21,12 +21,9 @@ use tch::{nn, no_grad, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn bert_masked_lm() -> anyhow::Result<()> {
|
fn bert_masked_lm() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
|
||||||
let vocab_resource =
|
let weights_resource = RemoteResource::from_pretrained(BertModelResources::BERT);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
|
||||||
let weights_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let weights_path = weights_resource.get_local_path()?;
|
let weights_path = weights_resource.get_local_path()?;
|
||||||
@ -106,10 +103,8 @@ fn bert_masked_lm() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn bert_for_sequence_classification() -> anyhow::Result<()> {
|
fn bert_for_sequence_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
|
||||||
let vocab_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
|
|
||||||
@ -170,10 +165,8 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn bert_for_multiple_choice() -> anyhow::Result<()> {
|
fn bert_for_multiple_choice() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
|
||||||
let vocab_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
|
|
||||||
@ -230,10 +223,8 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn bert_for_token_classification() -> anyhow::Result<()> {
|
fn bert_for_token_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
|
||||||
let vocab_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
|
|
||||||
@ -295,10 +286,8 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn bert_for_question_answering() -> anyhow::Result<()> {
|
fn bert_for_question_answering() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT);
|
||||||
let vocab_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
|
|
||||||
@ -422,11 +411,9 @@ fn bert_question_answering() -> anyhow::Result<()> {
|
|||||||
// Set-up question answering model
|
// Set-up question answering model
|
||||||
let config = QuestionAnsweringConfig::new(
|
let config = QuestionAnsweringConfig::new(
|
||||||
ModelType::Bert,
|
ModelType::Bert,
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
|
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
|
||||||
BertConfigResources::BERT_QA,
|
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
|
||||||
)),
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
|
|
||||||
None, //merges resource only relevant with ModelType::Roberta
|
None, //merges resource only relevant with ModelType::Roberta
|
||||||
false,
|
false,
|
||||||
false,
|
false,
|
||||||
|
@ -3,7 +3,7 @@ use rust_bert::deberta::{
|
|||||||
DebertaForSequenceClassification, DebertaForTokenClassification, DebertaMergesResources,
|
DebertaForSequenceClassification, DebertaForTokenClassification, DebertaMergesResources,
|
||||||
DebertaModelResources, DebertaVocabResources,
|
DebertaModelResources, DebertaVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -14,16 +14,16 @@ extern crate anyhow;
|
|||||||
#[test]
|
#[test]
|
||||||
fn deberta_natural_language_inference() -> anyhow::Result<()> {
|
fn deberta_natural_language_inference() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaVocabResources::DEBERTA_BASE_MNLI,
|
DebertaVocabResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaMergesResources::DEBERTA_BASE_MNLI,
|
DebertaMergesResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaModelResources::DEBERTA_BASE_MNLI,
|
DebertaModelResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ fn deberta_natural_language_inference() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn deberta_masked_lm() -> anyhow::Result<()> {
|
fn deberta_masked_lm() -> anyhow::Result<()> {
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -142,13 +142,13 @@ fn deberta_masked_lm() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn deberta_for_token_classification() -> anyhow::Result<()> {
|
fn deberta_for_token_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaVocabResources::DEBERTA_BASE_MNLI,
|
DebertaVocabResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaMergesResources::DEBERTA_BASE_MNLI,
|
DebertaMergesResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -203,13 +203,13 @@ fn deberta_for_token_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn deberta_for_question_answering() -> anyhow::Result<()> {
|
fn deberta_for_question_answering() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaVocabResources::DEBERTA_BASE_MNLI,
|
DebertaVocabResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DebertaMergesResources::DEBERTA_BASE_MNLI,
|
DebertaMergesResources::DEBERTA_BASE_MNLI,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
|
@ -5,7 +5,7 @@ use rust_bert::distilbert::{
|
|||||||
};
|
};
|
||||||
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||||
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
|
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::Vocab;
|
use rust_tokenizers::vocab::Vocab;
|
||||||
@ -42,13 +42,13 @@ fn distilbert_sentiment_classifier() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn distilbert_masked_lm() -> anyhow::Result<()> {
|
fn distilbert_masked_lm() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DistilBertConfigResources::DISTIL_BERT,
|
DistilBertConfigResources::DISTIL_BERT,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DistilBertVocabResources::DISTIL_BERT,
|
DistilBertVocabResources::DISTIL_BERT,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DistilBertModelResources::DISTIL_BERT,
|
DistilBertModelResources::DISTIL_BERT,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -123,10 +123,10 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn distilbert_for_question_answering() -> anyhow::Result<()> {
|
fn distilbert_for_question_answering() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -188,10 +188,10 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn distilbert_for_token_classification() -> anyhow::Result<()> {
|
fn distilbert_for_token_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DistilBertConfigResources::DISTIL_BERT,
|
DistilBertConfigResources::DISTIL_BERT,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
DistilBertVocabResources::DISTIL_BERT,
|
DistilBertVocabResources::DISTIL_BERT,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
|
@ -3,7 +3,7 @@ use rust_bert::gpt2::{
|
|||||||
Gpt2VocabResources,
|
Gpt2VocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
|
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
||||||
use tch::{nn, Device, Tensor};
|
use tch::{nn, Device, Tensor};
|
||||||
@ -11,16 +11,16 @@ use tch::{nn, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn distilgpt2_lm_model() -> anyhow::Result<()> {
|
fn distilgpt2_lm_model() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
Gpt2ConfigResources::DISTIL_GPT2,
|
Gpt2ConfigResources::DISTIL_GPT2,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
Gpt2VocabResources::DISTIL_GPT2,
|
Gpt2VocabResources::DISTIL_GPT2,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
Gpt2MergesResources::DISTIL_GPT2,
|
Gpt2MergesResources::DISTIL_GPT2,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
Gpt2ModelResources::DISTIL_GPT2,
|
Gpt2ModelResources::DISTIL_GPT2,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
|
@ -2,7 +2,7 @@ use rust_bert::electra::{
|
|||||||
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM,
|
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM,
|
||||||
ElectraModelResources, ElectraVocabResources,
|
ElectraModelResources, ElectraVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::Vocab;
|
use rust_tokenizers::vocab::Vocab;
|
||||||
@ -11,13 +11,13 @@ use tch::{nn, no_grad, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn electra_masked_lm() -> anyhow::Result<()> {
|
fn electra_masked_lm() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ElectraConfigResources::BASE_GENERATOR,
|
ElectraConfigResources::BASE_GENERATOR,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ElectraVocabResources::BASE_GENERATOR,
|
ElectraVocabResources::BASE_GENERATOR,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ElectraModelResources::BASE_GENERATOR,
|
ElectraModelResources::BASE_GENERATOR,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -95,13 +95,13 @@ fn electra_masked_lm() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn electra_discriminator() -> anyhow::Result<()> {
|
fn electra_discriminator() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ElectraConfigResources::BASE_DISCRIMINATOR,
|
ElectraConfigResources::BASE_DISCRIMINATOR,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ElectraVocabResources::BASE_DISCRIMINATOR,
|
ElectraVocabResources::BASE_DISCRIMINATOR,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ElectraModelResources::BASE_DISCRIMINATOR,
|
ElectraModelResources::BASE_DISCRIMINATOR,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
|
@ -7,7 +7,7 @@ use rust_bert::fnet::{
|
|||||||
};
|
};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel, SentimentPolarity};
|
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel, SentimentPolarity};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{FNetTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{FNetTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::Vocab;
|
use rust_tokenizers::vocab::Vocab;
|
||||||
@ -17,12 +17,9 @@ use tch::{nn, no_grad, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn fnet_masked_lm() -> anyhow::Result<()> {
|
fn fnet_masked_lm() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
|
||||||
let vocab_resource =
|
let weights_resource = Box::new(RemoteResource::from_pretrained(FNetModelResources::BASE));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
|
|
||||||
let weights_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(FNetModelResources::BASE));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let weights_path = weights_resource.get_local_path()?;
|
let weights_path = weights_resource.get_local_path()?;
|
||||||
@ -85,13 +82,13 @@ fn fnet_masked_lm() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn fnet_for_sequence_classification() -> anyhow::Result<()> {
|
fn fnet_for_sequence_classification() -> anyhow::Result<()> {
|
||||||
// Set up classifier
|
// Set up classifier
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
FNetConfigResources::BASE_SST2,
|
FNetConfigResources::BASE_SST2,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
FNetVocabResources::BASE_SST2,
|
FNetVocabResources::BASE_SST2,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
FNetModelResources::BASE_SST2,
|
FNetModelResources::BASE_SST2,
|
||||||
));
|
));
|
||||||
|
|
||||||
@ -128,10 +125,8 @@ fn fnet_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn fnet_for_multiple_choice() -> anyhow::Result<()> {
|
fn fnet_for_multiple_choice() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
|
||||||
let vocab_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
|
|
||||||
@ -188,10 +183,8 @@ fn fnet_for_multiple_choice() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn fnet_for_token_classification() -> anyhow::Result<()> {
|
fn fnet_for_token_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
|
||||||
let vocab_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
|
|
||||||
@ -251,10 +244,8 @@ fn fnet_for_token_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn fnet_for_question_answering() -> anyhow::Result<()> {
|
fn fnet_for_question_answering() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(FNetConfigResources::BASE));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(FNetVocabResources::BASE));
|
||||||
let vocab_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(FNetVocabResources::BASE));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
|
|
||||||
|
122
tests/gpt2.rs
122
tests/gpt2.rs
@ -10,7 +10,7 @@ use rust_bert::pipelines::generation_utils::{
|
|||||||
Cache, GenerateConfig, GenerateOptions, LMHeadModel, LanguageGenerator,
|
Cache, GenerateConfig, GenerateOptions, LMHeadModel, LanguageGenerator,
|
||||||
};
|
};
|
||||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
||||||
use tch::{nn, Device, Tensor};
|
use tch::{nn, Device, Tensor};
|
||||||
@ -18,14 +18,10 @@ use tch::{nn, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_lm_model() -> anyhow::Result<()> {
|
fn gpt2_lm_model() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource =
|
let config_resource = RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let vocab_resource = RemoteResource::from_pretrained(Gpt2VocabResources::GPT2);
|
||||||
let vocab_resource =
|
let merges_resource = RemoteResource::from_pretrained(Gpt2MergesResources::GPT2);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
let weights_resource = RemoteResource::from_pretrained(Gpt2ModelResources::GPT2);
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
|
||||||
let weights_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
let merges_path = merges_resource.get_local_path()?;
|
||||||
@ -114,14 +110,10 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_generation_greedy() -> anyhow::Result<()> {
|
fn gpt2_generation_greedy() -> anyhow::Result<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
let vocab_resource =
|
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
|
||||||
let model_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
|
||||||
|
|
||||||
let generate_config = TextGenerationConfig {
|
let generate_config = TextGenerationConfig {
|
||||||
model_type: ModelType::GPT2,
|
model_type: ModelType::GPT2,
|
||||||
@ -150,14 +142,10 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_generation_beam_search() -> anyhow::Result<()> {
|
fn gpt2_generation_beam_search() -> anyhow::Result<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
let vocab_resource =
|
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
|
||||||
let model_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
|
||||||
|
|
||||||
let generate_config = TextGenerationConfig {
|
let generate_config = TextGenerationConfig {
|
||||||
model_type: ModelType::GPT2,
|
model_type: ModelType::GPT2,
|
||||||
@ -198,14 +186,10 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Result<()> {
|
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Result<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
let vocab_resource =
|
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
|
||||||
let model_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
|
||||||
|
|
||||||
let generate_config = TextGenerationConfig {
|
let generate_config = TextGenerationConfig {
|
||||||
model_type: ModelType::GPT2,
|
model_type: ModelType::GPT2,
|
||||||
@ -259,14 +243,10 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
|
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
let vocab_resource =
|
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
|
||||||
let model_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
|
||||||
|
|
||||||
let generate_config = TextGenerationConfig {
|
let generate_config = TextGenerationConfig {
|
||||||
model_type: ModelType::GPT2,
|
model_type: ModelType::GPT2,
|
||||||
@ -319,14 +299,10 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
|
fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
let vocab_resource =
|
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
|
||||||
let model_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
|
||||||
|
|
||||||
let generate_config = TextGenerationConfig {
|
let generate_config = TextGenerationConfig {
|
||||||
model_type: ModelType::GPT2,
|
model_type: ModelType::GPT2,
|
||||||
@ -381,14 +357,10 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
|
fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
let vocab_resource =
|
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
|
||||||
let model_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
|
||||||
|
|
||||||
fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||||
let paragraph_tokens = [198, 628];
|
let paragraph_tokens = [198, 628];
|
||||||
@ -450,14 +422,10 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
|
fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
let vocab_resource =
|
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
|
||||||
let model_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
|
||||||
|
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
max_length: 36,
|
max_length: 36,
|
||||||
@ -520,14 +488,10 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
|
fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
let vocab_resource =
|
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
|
||||||
let model_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
|
||||||
|
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
max_length: 36,
|
max_length: 36,
|
||||||
@ -590,14 +554,10 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
|
fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource =
|
let config_resource = Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
let vocab_resource =
|
let merges_resource = Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
|
||||||
let model_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
|
||||||
|
|
||||||
fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||||
let paragraph_tokens = [198, 628];
|
let paragraph_tokens = [198, 628];
|
||||||
|
@ -4,7 +4,7 @@ use rust_bert::gpt_neo::{
|
|||||||
};
|
};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
||||||
use tch::{nn, Device, Tensor};
|
use tch::{nn, Device, Tensor};
|
||||||
@ -12,16 +12,16 @@ use tch::{nn, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt_neo_lm() -> anyhow::Result<()> {
|
fn gpt_neo_lm() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoConfigResources::GPT_NEO_125M,
|
GptNeoConfigResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoVocabResources::GPT_NEO_125M,
|
GptNeoVocabResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoMergesResources::GPT_NEO_125M,
|
GptNeoMergesResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoModelResources::GPT_NEO_125M,
|
GptNeoModelResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -109,16 +109,16 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_generation_gpt_neo() -> anyhow::Result<()> {
|
fn test_generation_gpt_neo() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoConfigResources::GPT_NEO_125M,
|
GptNeoConfigResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoVocabResources::GPT_NEO_125M,
|
GptNeoVocabResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoMergesResources::GPT_NEO_125M,
|
GptNeoMergesResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
GptNeoModelResources::GPT_NEO_125M,
|
GptNeoModelResources::GPT_NEO_125M,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ use rust_bert::pipelines::common::ModelType;
|
|||||||
use rust_bert::pipelines::question_answering::{
|
use rust_bert::pipelines::question_answering::{
|
||||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::{RobertaVocab, Vocab};
|
use rust_tokenizers::vocab::{RobertaVocab, Vocab};
|
||||||
@ -21,18 +21,14 @@ use tch::{nn, no_grad, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn longformer_masked_lm() -> anyhow::Result<()> {
|
fn longformer_masked_lm() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource =
|
||||||
LongformerConfigResources::LONGFORMER_BASE_4096,
|
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
|
||||||
));
|
let vocab_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
|
||||||
LongformerVocabResources::LONGFORMER_BASE_4096,
|
let merges_resource =
|
||||||
));
|
RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource =
|
||||||
LongformerMergesResources::LONGFORMER_BASE_4096,
|
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_4096);
|
||||||
));
|
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
LongformerModelResources::LONGFORMER_BASE_4096,
|
|
||||||
));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
let merges_path = merges_resource.get_local_path()?;
|
||||||
@ -176,15 +172,12 @@ fn longformer_masked_lm() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn longformer_for_sequence_classification() -> anyhow::Result<()> {
|
fn longformer_for_sequence_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource =
|
||||||
LongformerConfigResources::LONGFORMER_BASE_4096,
|
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
|
||||||
));
|
let vocab_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
|
||||||
LongformerVocabResources::LONGFORMER_BASE_4096,
|
let merges_resource =
|
||||||
));
|
RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
LongformerMergesResources::LONGFORMER_BASE_4096,
|
|
||||||
));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
let merges_path = merges_resource.get_local_path()?;
|
||||||
@ -245,15 +238,12 @@ fn longformer_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn longformer_for_multiple_choice() -> anyhow::Result<()> {
|
fn longformer_for_multiple_choice() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource =
|
||||||
LongformerConfigResources::LONGFORMER_BASE_4096,
|
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
|
||||||
));
|
let vocab_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
|
||||||
LongformerVocabResources::LONGFORMER_BASE_4096,
|
let merges_resource =
|
||||||
));
|
RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
LongformerMergesResources::LONGFORMER_BASE_4096,
|
|
||||||
));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
let merges_path = merges_resource.get_local_path()?;
|
||||||
@ -321,15 +311,12 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn mobilebert_for_token_classification() -> anyhow::Result<()> {
|
fn mobilebert_for_token_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource =
|
||||||
LongformerConfigResources::LONGFORMER_BASE_4096,
|
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);
|
||||||
));
|
let vocab_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_4096);
|
||||||
LongformerVocabResources::LONGFORMER_BASE_4096,
|
let merges_resource =
|
||||||
));
|
RemoteResource::from_pretrained(LongformerMergesResources::LONGFORMER_BASE_4096);
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
LongformerMergesResources::LONGFORMER_BASE_4096,
|
|
||||||
));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
let merges_path = merges_resource.get_local_path()?;
|
||||||
@ -394,18 +381,12 @@ fn longformer_for_question_answering() -> anyhow::Result<()> {
|
|||||||
// Set-up Question Answering model
|
// Set-up Question Answering model
|
||||||
let config = QuestionAnsweringConfig::new(
|
let config = QuestionAnsweringConfig::new(
|
||||||
ModelType::Longformer,
|
ModelType::Longformer,
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
|
||||||
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
|
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
|
||||||
)),
|
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
Some(RemoteResource::from_pretrained(
|
||||||
LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
|
|
||||||
)),
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
|
|
||||||
)),
|
|
||||||
Some(Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
|
LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
|
||||||
))),
|
)),
|
||||||
false,
|
false,
|
||||||
None,
|
None,
|
||||||
false,
|
false,
|
||||||
|
@ -4,7 +4,7 @@ use rust_bert::m2m_100::{
|
|||||||
};
|
};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{M2M100Tokenizer, Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{M2M100Tokenizer, Tokenizer, TruncationStrategy};
|
||||||
use tch::{nn, Device, Tensor};
|
use tch::{nn, Device, Tensor};
|
||||||
@ -12,18 +12,10 @@ use tch::{nn, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn m2m100_lm_model() -> anyhow::Result<()> {
|
fn m2m100_lm_model() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
|
||||||
M2M100ConfigResources::M2M100_418M,
|
let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
|
||||||
));
|
let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
|
||||||
M2M100VocabResources::M2M100_418M,
|
|
||||||
));
|
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100MergesResources::M2M100_418M,
|
|
||||||
));
|
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100ModelResources::M2M100_418M,
|
|
||||||
));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
let merges_path = merges_resource.get_local_path()?;
|
||||||
@ -76,18 +68,10 @@ fn m2m100_lm_model() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn m2m100_translation() -> anyhow::Result<()> {
|
fn m2m100_translation() -> anyhow::Result<()> {
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
|
||||||
M2M100ModelResources::M2M100_418M,
|
let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
|
||||||
));
|
let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = RemoteResource::from_pretrained(M2M100MergesResources::M2M100_418M);
|
||||||
M2M100ConfigResources::M2M100_418M,
|
|
||||||
));
|
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100VocabResources::M2M100_418M,
|
|
||||||
));
|
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
M2M100MergesResources::M2M100_418M,
|
|
||||||
));
|
|
||||||
|
|
||||||
let source_languages = M2M100SourceLanguages::M2M100_418M;
|
let source_languages = M2M100SourceLanguages::M2M100_418M;
|
||||||
let target_languages = M2M100TargetLanguages::M2M100_418M;
|
let target_languages = M2M100TargetLanguages::M2M100_418M;
|
||||||
|
@ -6,25 +6,17 @@ use rust_bert::pipelines::common::ModelType;
|
|||||||
use rust_bert::pipelines::translation::{
|
use rust_bert::pipelines::translation::{
|
||||||
Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
|
Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
// #[cfg_attr(not(feature = "all-tests"), ignore)]
|
// #[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||||
fn test_translation() -> anyhow::Result<()> {
|
fn test_translation() -> anyhow::Result<()> {
|
||||||
// Set-up translation model
|
// Set-up translation model
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = RemoteResource::from_pretrained(MarianModelResources::ENGLISH2ROMANCE);
|
||||||
MarianModelResources::ENGLISH2ROMANCE,
|
let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ENGLISH2ROMANCE);
|
||||||
));
|
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ENGLISH2ROMANCE);
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ENGLISH2ROMANCE);
|
||||||
MarianConfigResources::ENGLISH2ROMANCE,
|
|
||||||
));
|
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
MarianVocabResources::ENGLISH2ROMANCE,
|
|
||||||
));
|
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
MarianSpmResources::ENGLISH2ROMANCE,
|
|
||||||
));
|
|
||||||
|
|
||||||
let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
|
let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
|
||||||
let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
|
let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
|
||||||
|
@ -3,7 +3,7 @@ use rust_bert::mbart::{
|
|||||||
};
|
};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
|
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy};
|
||||||
use tch::{nn, Device, Tensor};
|
use tch::{nn, Device, Tensor};
|
||||||
@ -11,13 +11,13 @@ use tch::{nn, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn mbart_lm_model() -> anyhow::Result<()> {
|
fn mbart_lm_model() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
MBartConfigResources::MBART50_MANY_TO_MANY,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MBartModelResources::MBART50_MANY_TO_MANY,
|
MBartModelResources::MBART50_MANY_TO_MANY,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
|
@ -5,7 +5,7 @@ use rust_bert::mobilebert::{
|
|||||||
MobileBertModelResources, MobileBertVocabResources,
|
MobileBertModelResources, MobileBertVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::pipelines::pos_tagging::POSModel;
|
use rust_bert::pipelines::pos_tagging::POSModel;
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||||
use rust_tokenizers::vocab::Vocab;
|
use rust_tokenizers::vocab::Vocab;
|
||||||
@ -15,13 +15,13 @@ use tch::{nn, no_grad, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn mobilebert_masked_model() -> anyhow::Result<()> {
|
fn mobilebert_masked_model() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertModelResources::MOBILEBERT_UNCASED,
|
MobileBertModelResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -111,10 +111,10 @@ fn mobilebert_masked_model() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
|
fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -162,10 +162,10 @@ fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
|
fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -220,10 +220,10 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn mobilebert_for_token_classification() -> anyhow::Result<()> {
|
fn mobilebert_for_token_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -273,10 +273,10 @@ fn mobilebert_for_token_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn mobilebert_for_question_answering() -> anyhow::Result<()> {
|
fn mobilebert_for_question_answering() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
|
@ -6,7 +6,7 @@ use rust_bert::openai_gpt::{
|
|||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
|
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
|
||||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
|
||||||
use tch::{nn, Device, Tensor};
|
use tch::{nn, Device, Tensor};
|
||||||
@ -14,16 +14,16 @@ use tch::{nn, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn openai_gpt_lm_model() -> anyhow::Result<()> {
|
fn openai_gpt_lm_model() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptConfigResources::GPT,
|
OpenAiGptConfigResources::GPT,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptVocabResources::GPT,
|
OpenAiGptVocabResources::GPT,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptMergesResources::GPT,
|
OpenAiGptMergesResources::GPT,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptModelResources::GPT,
|
OpenAiGptModelResources::GPT,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -104,16 +104,16 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
|
fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptConfigResources::GPT,
|
OpenAiGptConfigResources::GPT,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptVocabResources::GPT,
|
OpenAiGptVocabResources::GPT,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptMergesResources::GPT,
|
OpenAiGptMergesResources::GPT,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptModelResources::GPT,
|
OpenAiGptModelResources::GPT,
|
||||||
));
|
));
|
||||||
|
|
||||||
@ -146,16 +146,16 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
|
fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptConfigResources::GPT,
|
OpenAiGptConfigResources::GPT,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptVocabResources::GPT,
|
OpenAiGptVocabResources::GPT,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptMergesResources::GPT,
|
OpenAiGptMergesResources::GPT,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptModelResources::GPT,
|
OpenAiGptModelResources::GPT,
|
||||||
));
|
));
|
||||||
|
|
||||||
@ -199,16 +199,16 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Result<()> {
|
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptConfigResources::GPT,
|
OpenAiGptConfigResources::GPT,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptVocabResources::GPT,
|
OpenAiGptVocabResources::GPT,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptMergesResources::GPT,
|
OpenAiGptMergesResources::GPT,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptModelResources::GPT,
|
OpenAiGptModelResources::GPT,
|
||||||
));
|
));
|
||||||
|
|
||||||
@ -268,16 +268,16 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
|
|||||||
#[test]
|
#[test]
|
||||||
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
|
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptConfigResources::GPT,
|
OpenAiGptConfigResources::GPT,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptVocabResources::GPT,
|
OpenAiGptVocabResources::GPT,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptMergesResources::GPT,
|
OpenAiGptMergesResources::GPT,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
OpenAiGptModelResources::GPT,
|
OpenAiGptModelResources::GPT,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -2,19 +2,19 @@ use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationMode
|
|||||||
|
|
||||||
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
|
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn pegasus_summarization_greedy() -> anyhow::Result<()> {
|
fn pegasus_summarization_greedy() -> anyhow::Result<()> {
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
PegasusConfigResources::CNN_DAILYMAIL,
|
PegasusConfigResources::CNN_DAILYMAIL,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
PegasusVocabResources::CNN_DAILYMAIL,
|
PegasusVocabResources::CNN_DAILYMAIL,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
PegasusModelResources::CNN_DAILYMAIL,
|
PegasusModelResources::CNN_DAILYMAIL,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -4,19 +4,19 @@ use rust_bert::pipelines::common::ModelType;
|
|||||||
use rust_bert::prophetnet::{
|
use rust_bert::prophetnet::{
|
||||||
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
|
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn prophetnet_summarization_greedy() -> anyhow::Result<()> {
|
fn prophetnet_summarization_greedy() -> anyhow::Result<()> {
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
|
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
|
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
|
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ use rust_bert::reformer::{
|
|||||||
ReformerConfig, ReformerConfigResources, ReformerForQuestionAnswering,
|
ReformerConfig, ReformerConfigResources, ReformerForQuestionAnswering,
|
||||||
ReformerForSequenceClassification, ReformerModelResources, ReformerVocabResources,
|
ReformerForSequenceClassification, ReformerModelResources, ReformerVocabResources,
|
||||||
};
|
};
|
||||||
use rust_bert::resources::{LocalResource, RemoteResource, Resource};
|
use rust_bert::resources::{LocalResource, RemoteResource, ResourceProvider};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, ReformerTokenizer, TruncationStrategy};
|
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, ReformerTokenizer, TruncationStrategy};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -17,7 +17,7 @@ use tch::{nn, no_grad, Device, Tensor};
|
|||||||
fn test_generation_reformer() -> anyhow::Result<()> {
|
fn test_generation_reformer() -> anyhow::Result<()> {
|
||||||
// ===================================================
|
// ===================================================
|
||||||
// Modify resource to enforce seed
|
// Modify resource to enforce seed
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
|
|
||||||
@ -31,18 +31,18 @@ fn test_generation_reformer() -> anyhow::Result<()> {
|
|||||||
let _ = updated_config_file.write_all(serde_json::to_string(&config).unwrap().as_bytes());
|
let _ = updated_config_file.write_all(serde_json::to_string(&config).unwrap().as_bytes());
|
||||||
let updated_config_path = updated_config_file.into_temp_path();
|
let updated_config_path = updated_config_file.into_temp_path();
|
||||||
|
|
||||||
let config_resource = Resource::Local(LocalResource {
|
let config_resource = Box::new(LocalResource {
|
||||||
local_path: updated_config_path.to_path_buf(),
|
local_path: updated_config_path.to_path_buf(),
|
||||||
});
|
});
|
||||||
// ===================================================
|
// ===================================================
|
||||||
|
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerModelResources::CRIME_AND_PUNISHMENT,
|
ReformerModelResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
// Set-up translation model
|
// Set-up translation model
|
||||||
@ -79,10 +79,10 @@ fn test_generation_reformer() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn reformer_for_sequence_classification() -> anyhow::Result<()> {
|
fn reformer_for_sequence_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -145,10 +145,10 @@ fn reformer_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn reformer_for_question_answering() -> anyhow::Result<()> {
|
fn reformer_for_question_answering() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
|
@ -5,7 +5,7 @@ use rust_bert::pipelines::question_answering::{
|
|||||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||||
};
|
};
|
||||||
use rust_bert::pipelines::token_classification::TokenClassificationConfig;
|
use rust_bert::pipelines::token_classification::TokenClassificationConfig;
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::roberta::{
|
use rust_bert::roberta::{
|
||||||
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
|
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
|
||||||
RobertaForSequenceClassification, RobertaForTokenClassification, RobertaMergesResources,
|
RobertaForSequenceClassification, RobertaForTokenClassification, RobertaMergesResources,
|
||||||
@ -20,18 +20,13 @@ use tch::{nn, no_grad, Device, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn roberta_masked_lm() -> anyhow::Result<()> {
|
fn roberta_masked_lm() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource =
|
||||||
RobertaConfigResources::DISTILROBERTA_BASE,
|
RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
|
||||||
));
|
let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource =
|
||||||
RobertaVocabResources::DISTILROBERTA_BASE,
|
RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
|
||||||
));
|
let weights_resource =
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(RobertaModelResources::DISTILROBERTA_BASE);
|
||||||
RobertaMergesResources::DISTILROBERTA_BASE,
|
|
||||||
));
|
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
RobertaModelResources::DISTILROBERTA_BASE,
|
|
||||||
));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
let merges_path = merges_resource.get_local_path()?;
|
||||||
@ -116,15 +111,11 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn roberta_for_sequence_classification() -> anyhow::Result<()> {
|
fn roberta_for_sequence_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource =
|
||||||
RobertaConfigResources::DISTILROBERTA_BASE,
|
RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
|
||||||
));
|
let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource =
|
||||||
RobertaVocabResources::DISTILROBERTA_BASE,
|
RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
|
||||||
));
|
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
RobertaMergesResources::DISTILROBERTA_BASE,
|
|
||||||
));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
let merges_path = merges_resource.get_local_path()?;
|
||||||
@ -190,15 +181,11 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn roberta_for_multiple_choice() -> anyhow::Result<()> {
|
fn roberta_for_multiple_choice() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource =
|
||||||
RobertaConfigResources::DISTILROBERTA_BASE,
|
RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
|
||||||
));
|
let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource =
|
||||||
RobertaVocabResources::DISTILROBERTA_BASE,
|
RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
|
||||||
));
|
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
RobertaMergesResources::DISTILROBERTA_BASE,
|
|
||||||
));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
let merges_path = merges_resource.get_local_path()?;
|
||||||
@ -260,15 +247,11 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn roberta_for_token_classification() -> anyhow::Result<()> {
|
fn roberta_for_token_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource =
|
||||||
RobertaConfigResources::DISTILROBERTA_BASE,
|
RemoteResource::from_pretrained(RobertaConfigResources::DISTILROBERTA_BASE);
|
||||||
));
|
let vocab_resource = RemoteResource::from_pretrained(RobertaVocabResources::DISTILROBERTA_BASE);
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource =
|
||||||
RobertaVocabResources::DISTILROBERTA_BASE,
|
RemoteResource::from_pretrained(RobertaMergesResources::DISTILROBERTA_BASE);
|
||||||
));
|
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
RobertaMergesResources::DISTILROBERTA_BASE,
|
|
||||||
));
|
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
let vocab_path = vocab_resource.get_local_path()?;
|
let vocab_path = vocab_resource.get_local_path()?;
|
||||||
let merges_path = merges_resource.get_local_path()?;
|
let merges_path = merges_resource.get_local_path()?;
|
||||||
@ -337,18 +320,12 @@ fn roberta_question_answering() -> anyhow::Result<()> {
|
|||||||
// Set-up question answering model
|
// Set-up question answering model
|
||||||
let config = QuestionAnsweringConfig::new(
|
let config = QuestionAnsweringConfig::new(
|
||||||
ModelType::Roberta,
|
ModelType::Roberta,
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
RemoteResource::from_pretrained(RobertaModelResources::ROBERTA_QA),
|
||||||
RobertaModelResources::ROBERTA_QA,
|
RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA_QA),
|
||||||
)),
|
RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA_QA),
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
Some(RemoteResource::from_pretrained(
|
||||||
RobertaConfigResources::ROBERTA_QA,
|
|
||||||
)),
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
RobertaVocabResources::ROBERTA_QA,
|
|
||||||
)),
|
|
||||||
Some(Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
RobertaMergesResources::ROBERTA_QA,
|
RobertaMergesResources::ROBERTA_QA,
|
||||||
))),
|
)),
|
||||||
false,
|
false,
|
||||||
None,
|
None,
|
||||||
false,
|
false,
|
||||||
@ -378,13 +355,13 @@ fn xlm_roberta_german_ner() -> anyhow::Result<()> {
|
|||||||
// Set-up question answering model
|
// Set-up question answering model
|
||||||
let ner_config = TokenClassificationConfig {
|
let ner_config = TokenClassificationConfig {
|
||||||
model_type: ModelType::XLMRoberta,
|
model_type: ModelType::XLMRoberta,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
RobertaModelResources::XLM_ROBERTA_NER_DE,
|
RobertaModelResources::XLM_ROBERTA_NER_DE,
|
||||||
)),
|
)),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
RobertaConfigResources::XLM_ROBERTA_NER_DE,
|
RobertaConfigResources::XLM_ROBERTA_NER_DE,
|
||||||
)),
|
)),
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
vocab_resource: Box::new(RemoteResource::from_pretrained(
|
||||||
RobertaVocabResources::XLM_ROBERTA_NER_DE,
|
RobertaVocabResources::XLM_ROBERTA_NER_DE,
|
||||||
)),
|
)),
|
||||||
lower_case: false,
|
lower_case: false,
|
||||||
|
30
tests/t5.rs
30
tests/t5.rs
@ -1,20 +1,16 @@
|
|||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::RemoteResource;
|
||||||
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_translation_t5() -> anyhow::Result<()> {
|
fn test_translation_t5() -> anyhow::Result<()> {
|
||||||
let model_resource =
|
let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_SMALL);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
|
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL);
|
||||||
let config_resource =
|
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
|
let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
|
||||||
let vocab_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
|
|
||||||
let merges_resource =
|
|
||||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
|
|
||||||
|
|
||||||
let source_languages = [
|
let source_languages = [
|
||||||
Language::English,
|
Language::English,
|
||||||
@ -70,18 +66,10 @@ fn test_summarization_t5() -> anyhow::Result<()> {
|
|||||||
// Set-up translation model
|
// Set-up translation model
|
||||||
let summarization_config = SummarizationConfig {
|
let summarization_config = SummarizationConfig {
|
||||||
model_type: ModelType::T5,
|
model_type: ModelType::T5,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
model_resource: Box::new(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)),
|
||||||
T5ModelResources::T5_SMALL,
|
config_resource: Box::new(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)),
|
||||||
)),
|
vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
merges_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
|
||||||
T5ConfigResources::T5_SMALL,
|
|
||||||
)),
|
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
T5VocabResources::T5_SMALL,
|
|
||||||
)),
|
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
|
||||||
T5VocabResources::T5_SMALL,
|
|
||||||
)),
|
|
||||||
min_length: 30,
|
min_length: 30,
|
||||||
max_length: 200,
|
max_length: 200,
|
||||||
early_stopping: true,
|
early_stopping: true,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||||
use rust_bert::resources::{RemoteResource, Resource};
|
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||||
use rust_bert::xlnet::{
|
use rust_bert::xlnet::{
|
||||||
XLNetConfig, XLNetConfigResources, XLNetForMultipleChoice, XLNetForQuestionAnswering,
|
XLNetConfig, XLNetConfigResources, XLNetForMultipleChoice, XLNetForQuestionAnswering,
|
||||||
XLNetForSequenceClassification, XLNetForTokenClassification, XLNetLMHeadModel, XLNetModel,
|
XLNetForSequenceClassification, XLNetForTokenClassification, XLNetLMHeadModel, XLNetModel,
|
||||||
@ -15,13 +15,13 @@ use tch::{nn, no_grad, Device, Kind, Tensor};
|
|||||||
#[test]
|
#[test]
|
||||||
fn xlnet_base_model() -> anyhow::Result<()> {
|
fn xlnet_base_model() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetConfigResources::XLNET_BASE_CASED,
|
XLNetConfigResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetVocabResources::XLNET_BASE_CASED,
|
XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetModelResources::XLNET_BASE_CASED,
|
XLNetModelResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -122,13 +122,13 @@ fn xlnet_base_model() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn xlnet_lm_model() -> anyhow::Result<()> {
|
fn xlnet_lm_model() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetConfigResources::XLNET_BASE_CASED,
|
XLNetConfigResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetVocabResources::XLNET_BASE_CASED,
|
XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let weights_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetModelResources::XLNET_BASE_CASED,
|
XLNetModelResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -196,16 +196,16 @@ fn xlnet_lm_model() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn xlnet_generation_beam_search() -> anyhow::Result<()> {
|
fn xlnet_generation_beam_search() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetConfigResources::XLNET_BASE_CASED,
|
XLNetConfigResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetVocabResources::XLNET_BASE_CASED,
|
XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetVocabResources::XLNET_BASE_CASED,
|
XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetModelResources::XLNET_BASE_CASED,
|
XLNetModelResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
|
|
||||||
@ -239,10 +239,10 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
|
fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetConfigResources::XLNET_BASE_CASED,
|
XLNetConfigResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetVocabResources::XLNET_BASE_CASED,
|
XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -311,10 +311,10 @@ fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
|
fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetConfigResources::XLNET_BASE_CASED,
|
XLNetConfigResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetVocabResources::XLNET_BASE_CASED,
|
XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -379,10 +379,10 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn xlnet_for_token_classification() -> anyhow::Result<()> {
|
fn xlnet_for_token_classification() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetConfigResources::XLNET_BASE_CASED,
|
XLNetConfigResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetVocabResources::XLNET_BASE_CASED,
|
XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
@ -442,10 +442,10 @@ fn xlnet_for_token_classification() -> anyhow::Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn xlnet_for_question_answering() -> anyhow::Result<()> {
|
fn xlnet_for_question_answering() -> anyhow::Result<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetConfigResources::XLNET_BASE_CASED,
|
XLNetConfigResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||||
XLNetVocabResources::XLNET_BASE_CASED,
|
XLNetVocabResources::XLNET_BASE_CASED,
|
||||||
));
|
));
|
||||||
let config_path = config_resource.get_local_path()?;
|
let config_path = config_resource.get_local_path()?;
|
||||||
|
Loading…
Reference in New Issue
Block a user