MBart validation, weights updated

Guillaume B 2021-06-05 11:47:56 +02:00
parent 171c2c8634
commit b6db7cacfb
4 changed files with 16 additions and 14 deletions

View File

@@ -199,7 +199,7 @@ impl MBartDecoder {
             vec![config.d_model],
             Default::default(),
         );
-        let layer_norm = nn::layer_norm(p / "layernorm", vec![config.d_model], Default::default());
+        let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default());
         let embed_positions = MBartLearnedPositionalEmbedding::new(
             p / "embed_positions",

View File

@@ -157,7 +157,7 @@ impl MBartEncoder {
             Default::default(),
         );
-        let layer_norm = nn::layer_norm(p / "layernorm", vec![config.d_model], Default::default());
+        let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default());
         let embed_positions = MBartLearnedPositionalEmbedding::new(
             p / "embed_positions",
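
The renamed sub-path is what ties these modules to the converted checkpoint: tch-rs registers each variable under the path it was created with, and loading succeeds only if those names match the keys in the weight file (Hugging Face's MBart exposes the final normalization layer as "layer_norm", distinct from "layernorm_embedding"). A minimal standalone sketch of that naming behaviour, not part of this commit; the 1024 dimension stands in for config.d_model:

use tch::{nn, Device};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();
    // Registering the layer under "layer_norm" creates variables named
    // "layer_norm.weight" and "layer_norm.bias"; a checkpoint converted with
    // "layernorm.*" keys would fail to match them at load time.
    let _layer_norm = nn::layer_norm(&root / "layer_norm", vec![1024], Default::default());
    for (name, _tensor) in vs.variables() {
        println!("{}", name);
    }
}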

View File

@@ -579,7 +579,7 @@ impl MBartForSequenceClassification {
     /// # use rust_bert::Config;
     /// # use std::path::Path;
     /// # use tch::kind::Kind::{Int64, Double};
-    /// use rust_bert::mbart::{MBartConfig, MBartForConditionalGeneration};
+    /// use rust_bert::mbart::{MBartConfig, MBartForSequenceClassification};
     /// # let config_path = Path::new("path/to/config.json");
     /// # let vocab_path = Path::new("path/to/vocab.txt");
     /// # let device = Device::Cpu;
@@ -594,12 +594,11 @@ impl MBartForSequenceClassification {
     ///
     /// let model_output = no_grad(|| {
     ///     mbart_model
-    ///         .forward_t(Some(&input_tensor),
+    ///         .forward_t(&input_tensor,
     ///             Some(&encoder_attention_mask),
     ///             None,
     ///             Some(&target_tensor),
     ///             Some(&decoder_attention_mask),
-    ///             None,
     ///             false)
     /// });
     /// ```

View File

@@ -3,6 +3,7 @@ import numpy as np
 import torch
 import subprocess
 import argparse
+import sys
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -17,14 +18,16 @@ if __name__ == "__main__":
     nps = {}
     for k, v in weights.items():
         k = k.replace("gamma", "weight").replace("beta", "bias")
-        nps[k] = np.ascontiguousarray(v.cpu().numpy())
+        if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}:
+            continue
+        nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
+        print(k + str(sys.getsizeof(nps[k])))
     np.savez(target_folder / 'model.npz', **nps)
-    source = str(target_folder / 'model.npz')
-    target = str(target_folder / 'rust_model.ot')
-
-    toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
-    subprocess.run(
-        ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target],
-    )
+    # source = str(target_folder / 'model.npz')
+    # target = str(target_folder / 'rust_model.ot')
+    #
+    # toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
+    # subprocess.run(
+    #     ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target],
+    # )
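
With the convert-tensor step above now left to be run manually, the resulting rust_model.ot is loaded on the Rust side by matching registered variable names against the file's keys. A minimal sketch of that loading step, not taken from this commit; the sub-paths and the 1024 (d_model) dimension are illustrative only:

use tch::{nn, Device};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();
    // Illustrative paths mirroring the assumed checkpoint layout; the real
    // constructors in rust_bert::mbart build the full hierarchy.
    let model = &root / "model";
    let decoder = &model / "decoder";
    let _final_layer_norm =
        nn::layer_norm(&decoder / "layer_norm", vec![1024], Default::default());
    // Loading fails if a registered variable has no matching key in the file,
    // which is how a "layernorm" vs "layer_norm" mismatch surfaces.
    vs.load("path/to/rust_model.ot")?;
    Ok(())
}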