Updated resource list

This commit is contained in:
Guillaume B 2020-05-03 14:59:13 +02:00
parent 139eecace7
commit 83e43ffcd5

View File

@ -7,6 +7,7 @@ use rust_bert::roberta::{RobertaConfigResources, RobertaVocabResources, RobertaM
use rust_bert::bert::{BertConfigResources, BertVocabResources, BertModelResources};
use rust_bert::bart::{BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::{ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
/// This example downloads and caches all dependencies used in model tests. This allows for safe
/// multi threaded testing (two test using the same resource would otherwise download the file to
@ -14,7 +15,7 @@ use rust_bert::resources::{Resource, download_resource, RemoteResource};
fn download_distil_gpt2() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
@ -27,7 +28,7 @@ fn download_distil_gpt2() -> failure::Fallible<()> {
}
fn download_distilbert_sst2() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2));
@ -38,7 +39,7 @@ fn download_distilbert_sst2() -> failure::Fallible<()> {
}
fn download_distilbert_qa() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
@ -49,7 +50,7 @@ fn download_distilbert_qa() -> failure::Fallible<()> {
}
fn download_distilbert() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
@ -60,7 +61,7 @@ fn download_distilbert() -> failure::Fallible<()> {
}
fn download_gpt2() -> failure::Fallible<()> {
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
@ -73,7 +74,7 @@ fn download_gpt2() -> failure::Fallible<()> {
}
fn download_gpt() -> failure::Fallible<()> {
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
@ -86,7 +87,7 @@ fn download_gpt() -> failure::Fallible<()> {
}
fn download_roberta() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
@ -99,7 +100,7 @@ fn download_roberta() -> failure::Fallible<()> {
}
fn download_bert() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
@ -110,7 +111,7 @@ fn download_bert() -> failure::Fallible<()> {
}
fn download_bert_ner() -> failure::Fallible<()> {
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER));
@ -121,7 +122,7 @@ fn download_bert_ner() -> failure::Fallible<()> {
}
fn download_bart() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
@ -134,7 +135,7 @@ fn download_bart() -> failure::Fallible<()> {
}
fn download_bart_cnn() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN));
@ -146,6 +147,28 @@ fn download_bart_cnn() -> failure::Fallible<()> {
Ok(())
}
fn download_electra_generator() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn download_electra_discriminator() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn main() -> failure::Fallible<()> {
let _ = download_distil_gpt2();
let _ = download_distilbert_sst2();
@ -158,6 +181,8 @@ fn main() -> failure::Fallible<()> {
let _ = download_bert_ner();
let _ = download_bart();
let _ = download_bart_cnn();
let _ = download_electra_generator();
let _ = download_electra_discriminator();
Ok(())
}