Merge branch 'master' into clippy_check

This commit is contained in:
Guillaume B 2020-09-12 14:40:15 +02:00
commit eee7985dce
50 changed files with 444 additions and 463 deletions

View File

@ -9,6 +9,10 @@ jobs:
- cargo fmt -- --check
- script:
- cargo build --verbose
- os:
- windows
script:
- cargo build --verbose
- script:
- sudo apt-get install python3-pip python3-setuptools
- pip3 install --upgrade pip

View File

@ -37,10 +37,9 @@ serde = { version = "1.0.114", features = ["derive"] }
dirs = "3.0.1"
itertools = "0.9.0"
ordered-float = "2.0.0"
reqwest = "0.10.7"
cached-path = "0.4.3"
lazy_static = "1.4.0"
uuid = { version = "0.8.1", features = ["v4"] }
tokio = { version = "0.2.22", features = ["full"] }
thiserror = "1.0.20"
[dev-dependencies]

View File

@ -17,7 +17,7 @@ use rust_bert::albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
AlbertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
@ -33,9 +33,9 @@ fn main() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -16,7 +16,7 @@ use rust_bert::bart::{
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Tensor};
@ -31,10 +31,10 @@ fn main() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::cuda_if_available();

View File

@ -15,7 +15,7 @@ extern crate anyhow;
use rust_bert::bert::{
BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
@ -28,9 +28,9 @@ fn main() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -15,7 +15,7 @@ use rust_bert::distilbert::{
DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
@ -33,9 +33,9 @@ fn main() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -16,7 +16,7 @@ use rust_bert::openai_gpt::{
OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModelResources,
OpenAiGptVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};
@ -40,10 +40,10 @@ fn download_distil_gpt2() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DISTIL_GPT2,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -58,9 +58,9 @@ fn download_distilbert_sst2() -> anyhow::Result<()> {
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SST2,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -75,9 +75,9 @@ fn download_distilbert_qa() -> anyhow::Result<()> {
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -92,9 +92,9 @@ fn download_distilbert() -> anyhow::Result<()> {
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -108,10 +108,10 @@ fn download_gpt2() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -129,10 +129,10 @@ fn download_gpt() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -150,10 +150,10 @@ fn download_roberta() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -165,9 +165,9 @@ fn download_bert() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -182,9 +182,9 @@ fn download_bert_ner() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -198,10 +198,10 @@ fn download_bart() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -219,10 +219,10 @@ fn download_bart_cnn() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
BartModelResources::BART_CNN,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -237,9 +237,9 @@ fn download_electra_generator() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -254,9 +254,9 @@ fn download_electra_discriminator() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -271,9 +271,9 @@ fn download_albert_base_v2() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -291,10 +291,10 @@ fn _download_dialogpt() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DIALOGPT_MEDIUM,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -306,9 +306,9 @@ fn download_t5_small() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -326,10 +326,10 @@ fn download_roberta_qa() -> anyhow::Result<()> {
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA_QA,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -342,9 +342,9 @@ fn download_bert_qa() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
@ -359,9 +359,9 @@ fn download_xlm_roberta_ner_german() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::XLM_ROBERTA_NER_DE,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}

View File

@ -16,7 +16,7 @@ use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraModelResources,
ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Tensor};
@ -32,9 +32,9 @@ fn main() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -16,7 +16,7 @@ use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraForMaskedLM, ElectraModelResources,
ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
@ -32,9 +32,9 @@ fn main() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -17,7 +17,7 @@ use rust_bert::gpt2::{
Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -32,10 +32,10 @@ fn main() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -18,7 +18,7 @@ use rust_bert::openai_gpt::{
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -37,10 +37,10 @@ fn main() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -13,7 +13,7 @@
extern crate anyhow;
use rust_bert::bert::BertConfig;
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaMergesResources, RobertaModelResources,
RobertaVocabResources,
@ -36,10 +36,10 @@ fn main() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -33,7 +33,7 @@ pub struct AlbertVocabResources;
impl AlbertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/model.ot",
"albert-base-v2/model",
"https://cdn.huggingface.co/albert-base-v2/rust_model.ot",
);
}
@ -41,7 +41,7 @@ impl AlbertModelResources {
impl AlbertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/config.json",
"albert-base-v2/config",
"https://cdn.huggingface.co/albert-base-v2-config.json",
);
}
@ -49,7 +49,7 @@ impl AlbertConfigResources {
impl AlbertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/spiece.model",
"albert-base-v2/spiece",
"https://cdn.huggingface.co/albert-base-v2-spiece.model",
);
}

View File

@ -26,7 +26,7 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource {
@ -38,9 +38,9 @@
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: AlbertTokenizer =

View File

@ -39,22 +39,22 @@ pub struct BartMergesResources;
impl BartModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = (
"bart/model.ot",
"bart/model",
"https://cdn.huggingface.co/facebook/bart-large/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/model.ot",
"bart-cnn/model",
"https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/model.ot",
"bart-xsum/model",
"https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_MNLI: (&'static str, &'static str) = (
"bart-large-mnli/model.ot",
"bart-large-mnli/model",
"https://cdn.huggingface.co/facebook/bart-large-mnli/rust_model.ot",
);
}
@ -62,22 +62,22 @@ impl BartModelResources {
impl BartConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = (
"bart/config.json",
"bart/config",
"https://cdn.huggingface.co/facebook/bart-large/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/config.json",
"bart-cnn/config",
"https://cdn.huggingface.co/facebook/bart-large-cnn/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/config.json",
"bart-xsum/config",
"https://cdn.huggingface.co/facebook/bart-large-xsum/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_MNLI: (&'static str, &'static str) = (
"bart-large-mnli/config.json",
"bart-large-mnli/config",
"https://cdn.huggingface.co/facebook/bart-large-mnli/config.json",
);
}
@ -85,22 +85,22 @@ impl BartConfigResources {
impl BartVocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = (
"bart/vocab.txt",
"bart/vocab",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/vocab.txt",
"bart-cnn/vocab",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/vocab.txt",
"bart-xsum/vocab",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_MNLI: (&'static str, &'static str) = (
"bart-large-mnli/vocab.txt",
"bart-large-mnli/vocab",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
}
@ -108,22 +108,22 @@ impl BartVocabResources {
impl BartMergesResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = (
"bart/merges.txt",
"bart/merges",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/merges.txt",
"bart-cnn/merges",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/merges.txt",
"bart-xsum/merges",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_MNLI: (&'static str, &'static str) = (
"bart-large-mnli/merges.txt",
"bart-large-mnli/merges",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
}

View File

@ -21,7 +21,7 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource {
@ -36,10 +36,10 @@
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);

View File

@ -36,17 +36,17 @@ pub struct BertVocabResources;
impl BertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = (
"bert/model.ot",
"bert/model",
"https://cdn.huggingface.co/bert-base-uncased-rust_model.ot",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/model.ot",
"bert-ner/model",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/rust_model.ot",
);
/// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/model.ot",
"bert-qa/model",
"https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-rust_model.ot",
);
}
@ -54,17 +54,17 @@ impl BertModelResources {
impl BertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = (
"bert/config.json",
"bert/config",
"https://cdn.huggingface.co/bert-base-uncased-config.json",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/config.json",
"bert-ner/config",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/config.json",
);
/// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/config.json",
"bert-qa/config",
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
);
}
@ -72,17 +72,17 @@ impl BertConfigResources {
impl BertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = (
"bert/vocab.txt",
"bert/vocab",
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/vocab.txt",
"bert-ner/vocab",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/vocab.txt",
);
/// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/vocab.txt",
"bert-qa/vocab",
"https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
);
}

View File

@ -25,7 +25,7 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bert::{BertConfig, BertForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource {
@ -37,9 +37,9 @@
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: BertTokenizer =

View File

@ -36,7 +36,7 @@ where
/// let config_path = Path::new("path/to/config.json");
/// let config = Gpt2Config::from_file(config_path);
/// ```
fn from_file(path: &Path) -> T {
fn from_file<P: AsRef<Path>>(path: P) -> T {
let f = File::open(path).expect("Could not open configuration file.");
let br = BufReader::new(f);
let config: T = serde_json::from_reader(br).expect("could not parse configuration");

View File

@ -23,8 +23,8 @@ pub enum RustBertError {
ValueError(String),
}
impl From<reqwest::Error> for RustBertError {
fn from(error: reqwest::Error) -> Self {
impl From<cached_path::Error> for RustBertError {
fn from(error: cached_path::Error) -> Self {
RustBertError::FileDownloadError(error.to_string())
}
}

View File

@ -18,17 +18,14 @@
//! pre-trained models in each model module.
use crate::common::error::RustBertError;
use cached_path::Cache;
use lazy_static::lazy_static;
use reqwest::Client;
use std::env;
use std::path::PathBuf;
use std::{env, fs};
use tokio::prelude::*;
use tokio::runtime::Runtime;
use tokio::task;
extern crate dirs;
/// # Resource Enum expected by the `download_resource` function
/// # Resource Enum pointing to model, configuration or vocabulary resources
/// Can be of type:
/// - LocalResource
/// - RemoteResource
@ -39,7 +36,10 @@ pub enum Resource {
}
impl Resource {
/// Gets the local path for a given resource
/// Gets the local path for a given resource.
///
/// If the resource is a remote resource, it is downloaded and cached. Then the path
/// to the local cache is returned.
///
/// # Returns
///
@ -55,10 +55,14 @@ impl Resource {
/// });
/// let config_path = config_resource.get_local_path();
/// ```
pub fn get_local_path(&self) -> &PathBuf {
pub fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
match self {
Resource::Local(resource) => &resource.local_path,
Resource::Remote(resource) => &resource.local_path,
Resource::Local(resource) => Ok(resource.local_path.clone()),
Resource::Remote(resource) => {
let cached_path =
CACHE.cached_path_in_subdir(&resource.url, Some(&resource.cache_subdir))?;
Ok(cached_path)
}
}
}
}
@ -75,8 +79,8 @@ pub struct LocalResource {
pub struct RemoteResource {
/// Remote path/url for the resource
pub url: String,
/// Local path for the resource
pub local_path: PathBuf,
/// Local subdirectory of the cache root where this resource is saved
pub cache_subdir: String,
}
impl RemoteResource {
@ -86,7 +90,7 @@ impl RemoteResource {
/// # Arguments
///
/// * `url` - `&str` Location of the remote resource
/// * `target` - `PathBuf` Local path to save teh resource to
/// * `cache_subdir` - `&str` Local subdirectory of the cache root to save the resource to
///
/// # Returns
///
@ -96,16 +100,15 @@ impl RemoteResource {
///
/// ```no_run
/// use rust_bert::resources::{RemoteResource, Resource};
/// use std::path::PathBuf;
/// let config_resource = Resource::Remote(RemoteResource::new(
/// "http://config_json_location",
/// PathBuf::from("path/to/config.json"),
/// "configs",
/// ));
/// ```
pub fn new(url: &str, target: PathBuf) -> RemoteResource {
pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
RemoteResource {
url: url.to_string(),
local_path: target,
cache_subdir: cache_subdir.to_string(),
}
}
@ -126,16 +129,17 @@ impl RemoteResource {
/// ```no_run
/// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2/model.ot",
/// "distilbert-sst2",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
/// )));
/// ```
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
let name = name_url_tuple.0;
let name = name_url_tuple.0.to_string();
let url = name_url_tuple.1.to_string();
let mut local_path = CACHE_DIRECTORY.to_path_buf();
local_path.push(name);
RemoteResource { url, local_path }
RemoteResource {
url,
cache_subdir: name,
}
}
}
@ -144,7 +148,7 @@ lazy_static! {
/// # Global cache directory
/// If the environment variable `RUSTBERT_CACHE` is set, the cached model files are saved at that
/// location. Otherwise defaults to `~/.cache/.rustbert`.
pub static ref CACHE_DIRECTORY: PathBuf = _get_cache_directory();
pub static ref CACHE: Cache = Cache::builder().dir(_get_cache_directory()).build().unwrap();
}
fn _get_cache_directory() -> PathBuf {
@ -160,6 +164,10 @@ fn _get_cache_directory() -> PathBuf {
home
}
#[deprecated(
since = "0.9.1",
note = "Please use `Resource.get_local_path()` instead"
)]
/// # (Download) the resource and return a path to its local path
/// This function will download remote resources to their local path if they do not exist yet.
/// Then for both `LocalResource` and `RemoteResource`, it will return the local path to the resource.
@ -176,37 +184,13 @@ fn _get_cache_directory() -> PathBuf {
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{download_resource, RemoteResource, Resource};
/// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
/// )));
/// let local_path = download_resource(&model_resource);
/// let local_path = model_resource.get_local_path();
/// ```
pub fn download_resource(resource: &Resource) -> Result<&PathBuf, RustBertError> {
match resource {
Resource::Remote(remote_resource) => {
let target = remote_resource.local_path.clone();
let url = remote_resource.url.clone();
if !target.exists() {
println!("Downloading {} to {:?}", url, target);
fs::create_dir_all(target.parent().unwrap())?;
let mut rt = Runtime::new()?;
let local = task::LocalSet::new();
local.block_on(&mut rt, async {
let client = Client::new();
let output_file = tokio::fs::File::create(target).await?;
let mut output_file = tokio::io::BufWriter::new(output_file);
let mut response = client.get(&url).send().await?;
while let Some(chunk) = response.chunk().await? {
output_file.write(&chunk).await?;
}
output_file.flush().await?;
Ok::<(), RustBertError>(())
})?;
}
Ok(resource.get_local_path())
}
Resource::Local(_) => Ok(resource.get_local_path()),
}
pub fn download_resource(resource: &Resource) -> Result<PathBuf, RustBertError> {
resource.get_local_path()
}

View File

@ -32,17 +32,17 @@ pub struct DistilBertVocabResources;
impl DistilBertModelResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/model.ot",
"distilbert-sst2/model",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/model.ot",
"distilbert/model",
"https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/model.ot",
"distilbert-qa/model",
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot",
);
}
@ -50,17 +50,17 @@ impl DistilBertModelResources {
impl DistilBertConfigResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/config.json",
"distilbert-sst2/config",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/config.json",
"distilbert/config",
"https://cdn.huggingface.co/distilbert-base-uncased-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/config.json",
"distilbert-qa/config",
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json",
);
}
@ -68,17 +68,17 @@ impl DistilBertConfigResources {
impl DistilBertVocabResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/vocab.txt",
"distilbert-sst2/vocab",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/vocab.txt",
"distilbert/vocab",
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/vocab.txt",
"distilbert-qa/vocab",
"https://cdn.huggingface.co/bert-large-cased-vocab.txt",
);
}

View File

@ -27,7 +27,7 @@
//! DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM,
//! DistilBertModelResources, DistilBertVocabResources,
//! };
//! use rust_bert::resources::{download_resource, LocalResource, RemoteResource, Resource};
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource {
@ -39,9 +39,9 @@
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: BertTokenizer =

View File

@ -34,12 +34,12 @@ pub struct ElectraVocabResources;
impl ElectraModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/model.ot",
"electra-base-generator/model",
"https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/model.ot",
"electra-base-discriminator/model",
"https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot",
);
}
@ -47,12 +47,12 @@ impl ElectraModelResources {
impl ElectraConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/config.json",
"electra-base-generator/config",
"https://cdn.huggingface.co/google/electra-base-generator/config.json",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/config.json",
"electra-base-discriminator/config",
"https://cdn.huggingface.co/google/electra-base-discriminator/config.json",
);
}
@ -60,12 +60,12 @@ impl ElectraConfigResources {
impl ElectraVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/vocab.txt",
"electra-base-generator/vocab",
"https://cdn.huggingface.co/google/electra-base-generator/vocab.txt",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/vocab.txt",
"electra-base-discriminator/vocab",
"https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt",
);
}

View File

@ -29,7 +29,7 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource {
@ -41,9 +41,9 @@
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: BertTokenizer =

View File

@ -38,131 +38,125 @@ pub struct Gpt2MergesResources;
impl Gpt2ModelResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/model.ot",
"gpt2/model",
"https://cdn.huggingface.co/gpt2-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/model.ot",
"gpt2-medium/model",
"https://cdn.huggingface.co/gpt2-medium-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/model.ot",
"gpt2-large/model",
"https://cdn.huggingface.co/gpt2-large-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/model.ot",
"gpt2-xl/model",
"https://cdn.huggingface.co/gpt2-xl-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/model.ot",
"distilgpt2/model",
"https://cdn.huggingface.co/distilgpt2-rust_model.ot",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/model.ot",
"dialogpt-medium/model",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/rust_model.ot",
);
}
impl Gpt2ConfigResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/config.json",
"https://cdn.huggingface.co/gpt2-config.json",
);
pub const GPT2: (&'static str, &'static str) =
("gpt2/config", "https://cdn.huggingface.co/gpt2-config.json");
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/config.json",
"gpt2-medium/config",
"https://cdn.huggingface.co/gpt2-medium-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/config.json",
"gpt2-large/config",
"https://cdn.huggingface.co/gpt2-large-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/config.json",
"gpt2-xl/config",
"https://cdn.huggingface.co/gpt2-xl-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/config.json",
"distilgpt2/config",
"https://cdn.huggingface.co/distilgpt2-config.json",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/config.json",
"dialogpt-medium/config",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/config.json",
);
}
impl Gpt2VocabResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/vocab.txt",
"https://cdn.huggingface.co/gpt2-vocab.json",
);
pub const GPT2: (&'static str, &'static str) =
("gpt2/vocab", "https://cdn.huggingface.co/gpt2-vocab.json");
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/vocab.txt",
"gpt2-medium/vocab",
"https://cdn.huggingface.co/gpt2-medium-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/vocab.txt",
"gpt2-large/vocab",
"https://cdn.huggingface.co/gpt2-large-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/vocab.txt",
"gpt2-xl/vocab",
"https://cdn.huggingface.co/gpt2-xl-vocab.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/vocab.txt",
"distilgpt2/vocab",
"https://cdn.huggingface.co/distilgpt2-vocab.json",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/vocab.txt",
"dialogpt-medium/vocab",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/vocab.json",
);
}
impl Gpt2MergesResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/merges.txt",
"https://cdn.huggingface.co/gpt2-merges.txt",
);
pub const GPT2: (&'static str, &'static str) =
("gpt2/merges", "https://cdn.huggingface.co/gpt2-merges.txt");
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/merges.txt",
"gpt2-medium/merges",
"https://cdn.huggingface.co/gpt2-medium-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/merges.txt",
"gpt2-large/merges",
"https://cdn.huggingface.co/gpt2-large-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/merges.txt",
"gpt2-xl/merges",
"https://cdn.huggingface.co/gpt2-xl-merges.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/merges.txt",
"distilgpt2/merges",
"https://cdn.huggingface.co/distilgpt2-merges.txt",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/merges.txt",
"dialogpt-medium/merges",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/merges.txt",
);
}

View File

@ -20,7 +20,7 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource {
@ -35,10 +35,10 @@
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);

View File

@ -35,42 +35,42 @@ pub struct MarianPrefix;
impl MarianModelResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/model.ot",
"marian-mt-en-ROMANCE/model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/model.ot",
"marian-mt-ROMANCE-en/model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/model.ot",
"marian-mt-en-de/model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/model.ot",
"marian-mt-de-en/model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/model.ot",
"marian-mt-en-ru/model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/model.ot",
"marian-mt-ru-en/model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/model.ot",
"marian-mt-fr-de/model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/model.ot",
"marian-mt-de-fr/model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot",
);
}
@ -78,42 +78,42 @@ impl MarianModelResources {
impl MarianConfigResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/config.json",
"marian-mt-en-ROMANCE/config",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/config.json",
"marian-mt-ROMANCE-en/config",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/config.json",
"marian-mt-en-de/config",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/config.json",
"marian-mt-de-en/config",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/config.json",
"marian-mt-en-ru/config",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/config.json",
"marian-mt-ru-en/config",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/config.json",
"marian-mt-fr-de/config",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/config.json",
"marian-mt-de-fr/config",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json",
);
}
@ -121,42 +121,42 @@ impl MarianConfigResources {
// Pretrained Marian vocabulary resources: one associated constant per Opus-MT language pair.
// NOTE(review): the declared type is a 2-tuple, yet three string literals are listed for each
// constant. This text looks like a diff rendering with removed and added lines interleaved:
// presumably the first literal (ending in `.json`) is the pre-change name, the second
// (extension dropped) is its replacement, and the last is the download URL.
// TODO confirm against the actual source file.
impl MarianVocabResources {
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/vocab.json",
"marian-mt-en-ROMANCE/vocab",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/vocab.json",
"marian-mt-ROMANCE-en/vocab",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/vocab.json",
"marian-mt-en-de/vocab",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/vocab.json",
"marian-mt-de-en/vocab",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/vocab.json",
"marian-mt-en-ru/vocab",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/vocab.json",
"marian-mt-ru-en/vocab",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/vocab.json",
"marian-mt-fr-de/vocab",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/vocab.json",
"marian-mt-de-fr/vocab",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json",
);
}
@ -164,42 +164,42 @@ impl MarianVocabResources {
// Pretrained Marian SentencePiece model resources: one associated constant per Opus-MT
// language pair. Every URL points at the remote `source.spm` file, while the name
// literals use `spiece`.
// NOTE(review): the declared type is a 2-tuple, yet three string literals are listed for each
// constant. This text looks like a diff rendering with removed and added lines interleaved:
// presumably the first literal (ending in `.model`) is the pre-change name, the second
// (extension dropped) is its replacement, and the last is the download URL.
// TODO confirm against the actual source file.
impl MarianSpmResources {
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/spiece.model",
"marian-mt-en-ROMANCE/spiece",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/spiece.model",
"marian-mt-ROMANCE-en/spiece",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/spiece.model",
"marian-mt-en-de/spiece",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/spiece.model",
"marian-mt-de-en/spiece",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/spiece.model",
"marian-mt-en-ru/spiece",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/spiece.model",
"marian-mt-ru-en/spiece",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/spiece.model",
"marian-mt-fr-de/spiece",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm",
);
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/spiece.model",
"marian-mt-de-fr/spiece",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm",
);
}

View File

@ -21,7 +21,7 @@
//! # use std::path::PathBuf;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::marian::MarianForConditionalGeneration;
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
//!
@ -37,10 +37,10 @@
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let spiece_path = download_resource(&sentence_piece_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let spiece_path = sentence_piece_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);

View File

@ -20,7 +20,7 @@
//! # use std::path::PathBuf;
//! use rust_bert::gpt2::Gpt2Config;
//! use rust_bert::openai_gpt::OpenAiGptModel;
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource {
@ -35,10 +35,10 @@
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);

View File

@ -37,7 +37,7 @@ pub struct OpenAiGptMergesResources;
// Pretrained OpenAI GPT model-weights resource.
// NOTE(review): the declared type is a 2-tuple but three string literals are listed. This
// looks like a diff rendering with the removed and added lines interleaved: presumably
// "openai-gpt/model.ot" is the pre-change name, "openai-gpt/model" its replacement, and the
// last literal the download URL. TODO confirm against the actual source file.
impl OpenAiGptModelResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/model.ot",
"openai-gpt/model",
"https://cdn.huggingface.co/openai-gpt-rust_model.ot",
);
}
@ -45,7 +45,7 @@ impl OpenAiGptModelResources {
// Pretrained OpenAI GPT configuration resource.
// NOTE(review): the declared type is a 2-tuple but three string literals are listed. This
// looks like a diff rendering with the removed and added lines interleaved: presumably
// "openai-gpt/config.json" is the pre-change name, "openai-gpt/config" its replacement, and
// the last literal the download URL. TODO confirm against the actual source file.
impl OpenAiGptConfigResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/config.json",
"openai-gpt/config",
"https://cdn.huggingface.co/openai-gpt-config.json",
);
}
@ -53,7 +53,7 @@ impl OpenAiGptConfigResources {
// Pretrained OpenAI GPT vocabulary resource.
// NOTE(review): the declared type is a 2-tuple but three string literals are listed. This
// looks like a diff rendering with the removed and added lines interleaved: presumably
// "openai-gpt/vocab.txt" is the pre-change name, "openai-gpt/vocab" its replacement, and the
// last literal the download URL. TODO confirm against the actual source file.
impl OpenAiGptVocabResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = (
// The `.txt` name does not match the `.json` file served by the URL below;
// the extension-less name sidesteps that mismatch.
"openai-gpt/vocab.txt",
"openai-gpt/vocab",
"https://cdn.huggingface.co/openai-gpt-vocab.json",
);
}
@ -61,7 +61,7 @@ impl OpenAiGptVocabResources {
// Pretrained OpenAI GPT BPE merges resource.
// NOTE(review): the declared type is a 2-tuple but three string literals are listed. This
// looks like a diff rendering with the removed and added lines interleaved: presumably
// "openai-gpt/merges.txt" is the pre-change name, "openai-gpt/merges" its replacement, and
// the last literal the download URL. TODO confirm against the actual source file.
impl OpenAiGptMergesResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/merges.txt",
"openai-gpt/merges",
"https://cdn.huggingface.co/openai-gpt-merges.txt",
);
}

View File

@ -92,7 +92,7 @@ pub enum TokenizerOption {
impl ConfigOption {
/// Interface method to load a configuration from file
pub fn from_file(model_type: ModelType, path: &Path) -> Self {
pub fn from_file<P: AsRef<Path>>(model_type: ModelType, path: P) -> Self {
match model_type {
ModelType::Bart => ConfigOption::Bart(BartConfig::from_file(path)),
ModelType::Bert | ModelType::Roberta | ModelType::XLMRoberta => {

View File

@ -63,7 +63,7 @@ use crate::bart::{
BartModelResources, BartVocabResources, LayerState as BartLayerState,
};
use crate::common::error::RustBertError;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
@ -285,10 +285,10 @@ impl OpenAIGenerator {
generate_config.merges_resource.clone()
};
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&model_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device;
let mut var_store = nn::VarStore::new(device);
@ -403,10 +403,10 @@ impl GPT2Generator {
/// # }
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<GPT2Generator, RustBertError> {
let config_path = download_resource(&generate_config.config_resource)?;
let vocab_path = download_resource(&generate_config.vocab_resource)?;
let merges_path = download_resource(&generate_config.merges_resource)?;
let weights_path = download_resource(&generate_config.model_resource)?;
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@ -607,10 +607,10 @@ impl BartGenerator {
generate_config.merges_resource.clone()
};
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&model_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@ -881,10 +881,10 @@ impl MarianGenerator {
/// # }
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<MarianGenerator, RustBertError> {
let config_path = download_resource(&generate_config.config_resource)?;
let vocab_path = download_resource(&generate_config.vocab_resource)?;
let sentence_piece_path = download_resource(&generate_config.merges_resource)?;
let weights_path = download_resource(&generate_config.model_resource)?;
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let sentence_piece_path = generate_config.merges_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@ -1143,9 +1143,9 @@ impl T5Generator {
generate_config.vocab_resource.clone()
};
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&model_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();

View File

@ -46,7 +46,7 @@
use crate::albert::AlbertForQuestionAnswering;
use crate::bert::BertForQuestionAnswering;
use crate::common::error::RustBertError;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::common::resources::{RemoteResource, Resource};
use crate::distilbert::{
DistilBertConfigResources, DistilBertForQuestionAnswering, DistilBertModelResources,
DistilBertVocabResources,
@ -409,12 +409,12 @@ impl QuestionAnsweringModel {
pub fn new(
question_answering_config: QuestionAnsweringConfig,
) -> Result<QuestionAnsweringModel, RustBertError> {
let config_path = download_resource(&question_answering_config.config_resource)?;
let vocab_path = download_resource(&question_answering_config.vocab_resource)?;
let weights_path = download_resource(&question_answering_config.model_resource)?;
let config_path = question_answering_config.config_resource.get_local_path()?;
let vocab_path = question_answering_config.vocab_resource.get_local_path()?;
let weights_path = question_answering_config.model_resource.get_local_path()?;
let merges_path = if let Some(merges_resource) = &question_answering_config.merges_resource
{
Some(download_resource(merges_resource).expect("Failure downloading resource"))
Some(merges_resource.get_local_path()?)
} else {
None
};
@ -423,7 +423,7 @@ impl QuestionAnsweringModel {
let tokenizer = TokenizerOption::from_file(
question_answering_config.model_type,
vocab_path.to_str().unwrap(),
merges_path.map(|path| path.to_str().unwrap()),
merges_path.as_deref().map(|path| path.to_str().unwrap()),
question_answering_config.lower_case,
question_answering_config.strip_accents,
question_answering_config.add_prefix_space,

View File

@ -61,7 +61,7 @@ use crate::albert::AlbertForSequenceClassification;
use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification;
use crate::common::error::RustBertError;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::common::resources::{RemoteResource, Resource};
use crate::distilbert::{
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
DistilBertVocabResources,
@ -377,11 +377,11 @@ impl SequenceClassificationModel {
pub fn new(
config: SequenceClassificationConfig,
) -> Result<SequenceClassificationModel, RustBertError> {
let config_path = download_resource(&config.config_resource)?;
let vocab_path = download_resource(&config.vocab_resource)?;
let weights_path = download_resource(&config.model_resource)?;
let config_path = config.config_resource.get_local_path()?;
let vocab_path = config.vocab_resource.get_local_path()?;
let weights_path = config.model_resource.get_local_path()?;
let merges_path = if let Some(merges_resource) = &config.merges_resource {
Some(download_resource(merges_resource).expect("Failure downloading resource"))
Some(merges_resource.get_local_path()?)
} else {
None
};
@ -390,7 +390,7 @@ impl SequenceClassificationModel {
let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.map(|path| path.to_str().unwrap()),
merges_path.as_deref().map(|path| path.to_str().unwrap()),
config.lower_case,
config.strip_accents,
config.add_prefix_space,

View File

@ -116,7 +116,7 @@ use crate::bert::{
BertConfigResources, BertForTokenClassification, BertModelResources, BertVocabResources,
};
use crate::common::error::RustBertError;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::common::resources::{RemoteResource, Resource};
use crate::distilbert::DistilBertForTokenClassification;
use crate::electra::ElectraForTokenClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
@ -485,11 +485,11 @@ impl TokenClassificationModel {
pub fn new(
config: TokenClassificationConfig,
) -> Result<TokenClassificationModel, RustBertError> {
let config_path = download_resource(&config.config_resource)?;
let vocab_path = download_resource(&config.vocab_resource)?;
let weights_path = download_resource(&config.model_resource)?;
let config_path = config.config_resource.get_local_path()?;
let vocab_path = config.vocab_resource.get_local_path()?;
let weights_path = config.model_resource.get_local_path()?;
let merges_path = if let Some(merges_resource) = &config.merges_resource {
Some(download_resource(merges_resource).expect("Failure downloading resource"))
Some(merges_resource.get_local_path()?)
} else {
None
};
@ -499,7 +499,7 @@ impl TokenClassificationModel {
let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.map(|path| path.to_str().unwrap()),
merges_path.as_deref().map(|path| path.to_str().unwrap()),
config.lower_case,
config.strip_accents,
config.add_prefix_space,

View File

@ -107,7 +107,7 @@ use crate::bert::BertForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::sequence_classification::Label;
use crate::resources::{download_resource, RemoteResource, Resource};
use crate::resources::{RemoteResource, Resource};
use crate::roberta::RobertaForSequenceClassification;
use crate::RustBertError;
use itertools::Itertools;
@ -407,11 +407,11 @@ impl ZeroShotClassificationModel {
pub fn new(
config: ZeroShotClassificationConfig,
) -> Result<ZeroShotClassificationModel, RustBertError> {
let config_path = download_resource(&config.config_resource)?;
let vocab_path = download_resource(&config.vocab_resource)?;
let weights_path = download_resource(&config.model_resource)?;
let config_path = config.config_resource.get_local_path()?;
let vocab_path = config.vocab_resource.get_local_path()?;
let weights_path = config.model_resource.get_local_path()?;
let merges_path = if let Some(merges_resource) = &config.merges_resource {
Some(download_resource(merges_resource).expect("Failure downloading resource"))
Some(merges_resource.get_local_path()?)
} else {
None
};
@ -420,7 +420,7 @@ impl ZeroShotClassificationModel {
let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.map(|path| path.to_str().unwrap()),
merges_path.as_deref().map(|path| path.to_str().unwrap()),
config.lower_case,
config.strip_accents,
config.add_prefix_space,

View File

@ -25,7 +25,7 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bert::BertConfig;
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::roberta::RobertaForMaskedLM;
//! use rust_bert::Config;
//!
@ -41,10 +41,10 @@
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let merges_path = merges_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);

View File

@ -35,32 +35,32 @@ pub struct RobertaMergesResources;
// Pretrained RoBERTa / XLM-RoBERTa model-weights resources; every URL points at a
// `rust_model.ot` file.
// NOTE(review): the declared type is a 2-tuple, yet three string literals are listed for each
// constant. This text looks like a diff rendering with removed and added lines interleaved:
// presumably the first literal (ending in `.ot`) is the pre-change name, the second
// (extension dropped) is its replacement, and the last is the download URL.
// TODO confirm against the actual source file.
impl RobertaModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/model.ot",
"roberta/model",
"https://cdn.huggingface.co/roberta-base-rust_model.ot",
);
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format.
pub const ROBERTA_QA: (&'static str, &'static str) = (
"roberta-qa/model.ot",
"roberta-qa/model",
"https://cdn.huggingface.co/deepset/roberta-base-squad2/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
"xlm-roberta-ner-en/model.ot",
"xlm-roberta-ner-en/model",
"https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
"xlm-roberta-ner-de/model.ot",
"xlm-roberta-ner-de/model",
"https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-german-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
"xlm-roberta-ner-nl/model.ot",
"xlm-roberta-ner-nl/model",
"https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-dutch-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
"xlm-roberta-ner-es/model.ot",
"xlm-roberta-ner-es/model",
"https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-spanish-rust_model.ot",
);
}
@ -68,32 +68,32 @@ impl RobertaModelResources {
// Pretrained RoBERTa / XLM-RoBERTa configuration resources. `ROBERTA` is served from
// cdn.huggingface.co, while the remaining entries use the s3.amazonaws.com host.
// NOTE(review): the declared type is a 2-tuple, yet three string literals are listed for each
// constant. This text looks like a diff rendering with removed and added lines interleaved:
// presumably the first literal (ending in `.json`) is the pre-change name, the second
// (extension dropped) is its replacement, and the last is the download URL.
// TODO confirm against the actual source file.
impl RobertaConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/config.json",
"roberta/config",
"https://cdn.huggingface.co/roberta-base-config.json",
);
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format.
pub const ROBERTA_QA: (&'static str, &'static str) = (
"roberta-qa/config.json",
"roberta-qa/config",
"https://s3.amazonaws.com/models.huggingface.co/bert/deepset/roberta-base-squad2/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
"xlm-roberta-ner-en/config.json",
"xlm-roberta-ner-en/config",
"https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
"xlm-roberta-ner-de/config.json",
"xlm-roberta-ner-de/config",
"https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
"xlm-roberta-ner-nl/config.json",
"xlm-roberta-ner-nl/config",
"https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
"xlm-roberta-ner-es/config.json",
"xlm-roberta-ner-es/config",
"https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
);
}
@ -101,32 +101,32 @@ impl RobertaConfigResources {
// Pretrained RoBERTa / XLM-RoBERTa vocabulary resources. The two RoBERTa entries point at
// `vocab.json` files; the XLM-RoBERTa NER entries point at `sentencepiece.bpe.model` files
// and use `spiece` in their names.
// NOTE(review): the declared type is a 2-tuple, yet three string literals are listed for each
// constant. This text looks like a diff rendering with removed and added lines interleaved:
// presumably the first literal (with a file extension) is the pre-change name, the second
// (extension dropped) is its replacement, and the last is the download URL.
// TODO confirm against the actual source file.
impl RobertaVocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = (
// The `.txt` name does not match the `.json` file served by the URL below;
// the extension-less name sidesteps that mismatch.
"roberta/vocab.txt",
"roberta/vocab",
"https://cdn.huggingface.co/roberta-base-vocab.json",
);
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format.
pub const ROBERTA_QA: (&'static str, &'static str) = (
"roberta-qa/vocab.json",
"roberta-qa/vocab",
"https://cdn.huggingface.co/deepset/roberta-base-squad2/vocab.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
"xlm-roberta-ner-en/spiece.model",
"xlm-roberta-ner-en/spiece",
"https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
"xlm-roberta-ner-de/spiece.model",
"xlm-roberta-ner-de/spiece",
"https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-german-sentencepiece.bpe.model",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
"xlm-roberta-ner-nl/spiece.model",
"xlm-roberta-ner-nl/spiece",
"https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-dutch-sentencepiece.bpe.model",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
"xlm-roberta-ner-es/spiece.model",
"xlm-roberta-ner-es/spiece",
"https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-spanish-sentencepiece.bpe.model",
);
}
@ -134,12 +134,12 @@ impl RobertaVocabResources {
// Pretrained RoBERTa BPE merges resources; only the two RoBERTa variants are listed here.
// NOTE(review): the declared type is a 2-tuple, yet three string literals are listed for each
// constant. This text looks like a diff rendering with removed and added lines interleaved:
// presumably the first literal (ending in `.txt`) is the pre-change name, the second
// (extension dropped) is its replacement, and the last is the download URL.
// TODO confirm against the actual source file.
impl RobertaMergesResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/merges.txt",
"roberta/merges",
"https://cdn.huggingface.co/roberta-base-merges.txt",
);
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. Modified with conversion to C-array format.
pub const ROBERTA_QA: (&'static str, &'static str) = (
"roberta-qa/merges.txt",
"roberta-qa/merges",
"https://cdn.huggingface.co/deepset/roberta-base-squad2/merges.txt",
);
}

View File

@ -19,7 +19,7 @@
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::t5::{T5Config, T5ForConditionalGeneration};
//! use rust_bert::Config;
//! use rust_tokenizers::preprocessing::tokenizer::t5_tokenizer::T5Tokenizer;
@ -33,9 +33,9 @@
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let spiece_path = download_resource(&sentence_piece_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let config_path = config_resource.get_local_path()?;
//! let spiece_path = sentence_piece_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);

View File

@ -33,12 +33,12 @@ pub struct T5Prefix;
// Pretrained T5 model-weights resources for the small and base variants.
// NOTE(review): the declared type is a 2-tuple, yet three string literals are listed for each
// constant. This text looks like a diff rendering with removed and added lines interleaved:
// presumably the first literal (ending in `.ot`) is the pre-change name, the second
// (extension dropped) is its replacement, and the last is the download URL.
// TODO confirm against the actual source file.
impl T5ModelResources {
/// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format.
pub const T5_SMALL: (&'static str, &'static str) = (
"t5-small/model.ot",
"t5-small/model",
"https://cdn.huggingface.co/t5-small/rust_model.ot",
);
/// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format.
pub const T5_BASE: (&'static str, &'static str) = (
"t5-base/model.ot",
"t5-base/model",
"https://cdn.huggingface.co/t5-base/rust_model.ot",
);
}
@ -46,12 +46,12 @@ impl T5ModelResources {
// Pretrained T5 configuration resources for the small and base variants.
// NOTE(review): the declared type is a 2-tuple, yet three string literals are listed for each
// constant. This text looks like a diff rendering with removed and added lines interleaved:
// presumably the first literal (ending in `.json`) is the pre-change name, the second
// (extension dropped) is its replacement, and the last is the download URL.
// TODO confirm against the actual source file.
impl T5ConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
pub const T5_SMALL: (&'static str, &'static str) = (
"t5-small/config.json",
"t5-small/config",
"https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
pub const T5_BASE: (&'static str, &'static str) = (
"t5-base/config.json",
"t5-base/config",
"https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
);
}
@ -59,12 +59,12 @@ impl T5ConfigResources {
// Pretrained T5 SentencePiece vocabulary resources. Both constants reference the same
// remote `t5-spiece.model` URL but carry distinct per-variant names.
// NOTE(review): the declared type is a 2-tuple, yet three string literals are listed for each
// constant. This text looks like a diff rendering with removed and added lines interleaved:
// presumably the first literal (ending in `.model`) is the pre-change name, the second
// (extension dropped) is its replacement, and the last is the download URL.
// TODO confirm against the actual source file.
impl T5VocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
pub const T5_SMALL: (&'static str, &'static str) = (
"t5-small/spiece.model",
"t5-small/spiece",
"https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
pub const T5_BASE: (&'static str, &'static str) = (
"t5-base/spiece.model",
"t5-base/spiece",
"https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
);
}

View File

@ -6,7 +6,7 @@ use rust_bert::albert::{
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertModelResources, AlbertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
@ -24,9 +24,9 @@ fn albert_masked_lm() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
@ -85,8 +85,8 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
@ -152,8 +152,8 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
@ -219,8 +219,8 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
@ -287,8 +287,8 @@ fn albert_for_question_answering() -> anyhow::Result<()> {
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;

View File

@ -6,7 +6,7 @@ use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationMode
use rust_bert::pipelines::zero_shot_classification::{
ZeroShotClassificationConfig, ZeroShotClassificationModel,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -23,10 +23,10 @@ fn bart_lm_model() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -11,7 +11,7 @@ use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
@ -26,9 +26,9 @@ fn bert_masked_lm() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
@ -102,8 +102,8 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
@ -167,8 +167,8 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
@ -229,8 +229,8 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
@ -295,8 +295,8 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;

View File

@ -5,7 +5,7 @@ use rust_bert::distilbert::{
};
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
@ -52,9 +52,9 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::cuda_if_available();
@ -123,8 +123,8 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::cuda_if_available();
@ -183,8 +183,8 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::cuda_if_available();

View File

@ -3,7 +3,7 @@ use rust_bert::gpt2::{
Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -23,10 +23,10 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DISTIL_GPT2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -2,7 +2,7 @@ use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM,
ElectraModelResources, ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
@ -19,9 +19,9 @@ fn electra_masked_lm() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
@ -93,9 +93,9 @@ fn electra_discriminator() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -8,7 +8,7 @@ use rust_bert::pipelines::conversation::{
use rust_bert::pipelines::generation::{
Cache, GPT2Generator, GenerateConfig, LMHeadModel, LanguageGenerator,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -24,10 +24,10 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -6,7 +6,7 @@ use rust_bert::openai_gpt::{
use rust_bert::pipelines::generation::{
Cache, GenerateConfig, LMHeadModel, LanguageGenerator, OpenAIGenerator,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
@ -26,10 +26,10 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;

View File

@ -5,7 +5,7 @@ use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::pipelines::token_classification::TokenClassificationConfig;
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
RobertaForSequenceClassification, RobertaForTokenClassification, RobertaMergesResources,
@ -31,10 +31,10 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
@ -117,9 +117,9 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
@ -192,9 +192,9 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
@ -264,9 +264,9 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;