Definition and loading of Reformer config

This commit is contained in:
Guillaume B 2020-11-01 09:40:53 +01:00
parent 4e65a553ec
commit ed346d34ac
6 changed files with 179 additions and 1 deletions

View File

@ -53,7 +53,7 @@ all-tests = []
features = ["doc-only"]
[dependencies]
rust_tokenizers = "~6.0.0"
rust_tokenizers = { version = "~6.1.0", path = "E:/Coding/backup-rust/rust-tokenizers/main" }
tch = "~0.2.1"
serde_json = "1.0.59"
serde = { version = "1.0.117", features = ["derive"] }

49
examples/reformer.rs Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2018 Carnegie Mellon University Authors.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::reformer::{
ReformerConfig, ReformerConfigResources, ReformerModelResources, ReformerVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::ReformerTokenizer;
use tch::{nn, Device};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let _weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut _vs = nn::VarStore::new(device);
let _tokenizer = ReformerTokenizer::from_file(vocab_path.to_str().unwrap(), false)?;
let _config = ReformerConfig::from_file(config_path);
// let xlnet_model = XLNetLMHeadModel::new(&vs.root(), &config);
// vs.load(weights_path)?;
Ok(())
}

View File

@ -68,6 +68,7 @@ pub mod gpt2;
pub mod marian;
pub mod openai_gpt;
pub mod pipelines;
pub mod reformer;
pub mod roberta;
pub mod t5;
pub mod xlnet;

24
src/reformer/attention.rs Normal file
View File

@ -0,0 +1,24 @@
// Copyright 2020 The Trax Authors and The HuggingFace Inc. team.
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
/// # Attention type for the model (local or LSH)
pub enum AttentionType {
/// Local attention
local,
/// LSH attention
lsh,
}

6
src/reformer/mod.rs Normal file
View File

@ -0,0 +1,6 @@
mod attention;
mod reformer_model;
pub use reformer_model::{
ReformerConfig, ReformerConfigResources, ReformerModelResources, ReformerVocabResources,
};

View File

@ -0,0 +1,98 @@
// Copyright 2020 The Trax Authors and The HuggingFace Inc. team.
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::activations::Activation;
use crate::reformer::attention::AttentionType;
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// # Reformer Pretrained model weight files
pub struct ReformerModelResources;
/// # Reformer Pretrained model config files
pub struct ReformerConfigResources;
/// # Reformer Pretrained model vocab files
pub struct ReformerVocabResources;
impl ReformerModelResources {
/// Shared under Apache 2.0 license by the Trax Authors at https://github.com/google/trax/tree/master/trax/models/reformer. Modified with conversion to C-array format.
pub const CRIME_AND_PUNISHMENT: (&'static str, &'static str) = (
"xlnet-base-cased/model",
"https://cdn.huggingface.co/google/reformer-crime-and-punishment/rust_model.ot",
);
}
impl ReformerConfigResources {
/// Shared under Apache 2.0 license by the Trax Authors at https://github.com/google/trax/tree/master/trax/models/reformer. Modified with conversion to C-array format.
pub const CRIME_AND_PUNISHMENT: (&'static str, &'static str) = (
"xlnet-base-cased/config",
"https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json",
);
}
impl ReformerVocabResources {
/// Shared under Apache 2.0 license by the Trax Authors at https://github.com/google/trax/tree/master/trax/models/reformer. Modified with conversion to C-array format.
pub const CRIME_AND_PUNISHMENT: (&'static str, &'static str) = (
"xlnet-base-cased/spiece",
"https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model",
);
}
#[derive(Debug, Serialize, Deserialize)]
/// # Reformer model configuration
/// Defines the Reformer model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct ReformerConfig {
pub attention_head_size: i64,
pub attention_probs_dropout_prob: f64,
pub attn_layers: Vec<AttentionType>,
pub axial_norm_std: f64,
pub axial_pos_embds: bool,
pub axial_pos_embds_dim: Vec<i64>,
pub axial_pos_shape: Vec<i64>,
pub chunk_size_lm_head: i64,
pub chunk_size_feed_forward: Option<i64>,
pub eos_token_id: i64,
pub pad_token_id: i64,
pub feed_forward_size: i64,
pub hash_seed: Option<i64>,
pub hidden_act: Activation,
pub hidden_dropout_prob: f64,
pub hidden_size: i64,
pub initializer_range: Option<f64>,
pub intermediate_size: i64,
pub is_decoder: bool,
pub layer_norm_eps: Option<f64>,
pub local_attn_chunk_length: i64,
pub lsh_attn_chunk_length: i64,
pub max_position_embeddings: i64,
pub vocab_size: i64,
pub num_attention_heads: i64,
pub num_buckets: Vec<i64>,
pub local_num_chunks_after: Option<i64>,
pub local_num_chunks_before: Option<i64>,
pub local_attention_probs_dropout_prob: Option<f64>,
pub lsh_num_chunks_after: Option<i64>,
pub lsh_num_chunks_before: Option<i64>,
pub lsh_attention_probs_dropout_prob: Option<f64>,
pub num_hashes: i64,
pub num_hidden_layers: i64,
pub use_cache: Option<bool>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
}
impl Config<ReformerConfig> for ReformerConfig {}