Addition of Electra resources

This commit is contained in:
Guillaume B 2020-05-03 10:05:32 +02:00
parent 5a1c1ae7a0
commit 029d4bd47c
5 changed files with 48 additions and 21 deletions

View File

@ -13,22 +13,17 @@
// limitations under the License.
use rust_bert::resources::{LocalResource, Resource, download_resource};
use std::path::PathBuf;
use rust_bert::electra::electra::{ElectraConfig, ElectraDiscriminator};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::electra::{ElectraConfig, ElectraDiscriminator, ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
use tch::{Tensor, Device, nn, no_grad};
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("electra-discriminator");
let config_resource = Resource::Local(LocalResource { local_path: home.as_path().join("config.json") });
let vocab_resource = Resource::Local(LocalResource { local_path: home.as_path().join("vocab.txt") });
let weights_resource = Resource::Local(LocalResource { local_path: home.as_path().join("model.ot") });
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;

View File

@ -13,22 +13,17 @@
// limitations under the License.
use rust_bert::resources::{LocalResource, Resource, download_resource};
use std::path::PathBuf;
use rust_bert::electra::electra::{ElectraConfig, ElectraForMaskedLM};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::electra::{ElectraConfig, ElectraForMaskedLM, ElectraModelResources, ElectraConfigResources, ElectraVocabResources};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{Tensor, Device, nn, no_grad};
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("electra-generator");
let config_resource = Resource::Local(LocalResource {local_path: home.as_path().join("config.json")});
let vocab_resource = Resource::Local(LocalResource {local_path: home.as_path().join("vocab.txt")});
let weights_resource = Resource::Local(LocalResource {local_path: home.as_path().join("model.ot")});
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;

View File

@ -22,6 +22,36 @@ use crate::bert::encoder::BertEncoder;
use crate::common::activations::{_gelu, _relu, _mish};
use crate::common::dropout::Dropout;
/// # Electra Pretrained model weight files
///
/// Namespace for `(name, URL)` pairs; each pair is the kind of value handed to
/// `RemoteResource::from_pretrained`.
pub struct ElectraModelResources;

/// # Electra Pretrained model config files
///
/// Namespace for `(name, URL)` pairs; each pair is the kind of value handed to
/// `RemoteResource::from_pretrained`.
pub struct ElectraConfigResources;

/// # Electra Pretrained model vocab files
///
/// Namespace for `(name, URL)` pairs; each pair is the kind of value handed to
/// `RemoteResource::from_pretrained`.
pub struct ElectraVocabResources;

impl ElectraModelResources {
    /// Base generator weights.
    ///
    /// Shared under Apache 2.0 license by the Google team at
    /// <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/model.ot",
        "https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot",
    );

    /// Base discriminator weights.
    ///
    /// Shared under Apache 2.0 license by the Google team at
    /// <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/model.ot",
        "https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot",
    );
}

impl ElectraConfigResources {
    /// Base generator configuration.
    ///
    /// Shared under Apache 2.0 license by the Google team at
    /// <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/config.json",
        "https://cdn.huggingface.co/google/electra-base-generator/config.json",
    );

    /// Base discriminator configuration.
    ///
    /// Shared under Apache 2.0 license by the Google team at
    /// <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/config.json",
        "https://cdn.huggingface.co/google/electra-base-discriminator/config.json",
    );
}

impl ElectraVocabResources {
    /// Base generator vocabulary.
    ///
    /// Shared under Apache 2.0 license by the Google team at
    /// <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/vocab.txt",
        "https://cdn.huggingface.co/google/electra-base-generator/vocab.txt",
    );

    /// Base discriminator vocabulary.
    ///
    /// Shared under Apache 2.0 license by the Google team at
    /// <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/vocab.txt",
        "https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt",
    );
}
#[derive(Debug, Serialize, Deserialize)]
/// # Electra model configuration
/// Defines the Electra model architecture (e.g. number of layers, hidden layer size, label mapping...)

8
tests/electra.rs Normal file
View File

@ -0,0 +1,8 @@
//#[test]
//fn electra_masked_lm() -> failure::Fallible<()> {
//
//}

View File

@ -32,7 +32,6 @@ weights = torch.load(temp_weights, map_location='cpu')
# Gather every entry of the loaded `weights` mapping as a NumPy array, keyed by its renamed key.
nps = {}
for k, v in weights.items():
# Replace the substrings "gamma"/"beta" in the key with "weight"/"bias".
k = k.replace("gamma", "weight").replace("beta", "bias")
print(k)
# Move the value to CPU, convert it to NumPy and force a C-contiguous layout.
nps[k] = np.ascontiguousarray(v.cpu().numpy())
# Store all arrays, one per key, in a single 'model.npz' archive under target_path.
np.savez(target_path / 'model.npz', **nps)