mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
MBart validation, weights updated
This commit is contained in:
parent
171c2c8634
commit
b6db7cacfb
@ -199,7 +199,7 @@ impl MBartDecoder {
|
||||
vec![config.d_model],
|
||||
Default::default(),
|
||||
);
|
||||
let layer_norm = nn::layer_norm(p / "layernorm", vec![config.d_model], Default::default());
|
||||
let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default());
|
||||
|
||||
let embed_positions = MBartLearnedPositionalEmbedding::new(
|
||||
p / "embed_positions",
|
||||
|
@ -157,7 +157,7 @@ impl MBartEncoder {
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let layer_norm = nn::layer_norm(p / "layernorm", vec![config.d_model], Default::default());
|
||||
let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default());
|
||||
|
||||
let embed_positions = MBartLearnedPositionalEmbedding::new(
|
||||
p / "embed_positions",
|
||||
|
@ -579,7 +579,7 @@ impl MBartForSequenceClassification {
|
||||
/// # use rust_bert::Config;
|
||||
/// # use std::path::Path;
|
||||
/// # use tch::kind::Kind::{Int64, Double};
|
||||
/// use rust_bert::mbart::{MBartConfig, MBartForConditionalGeneration};
|
||||
/// use rust_bert::mbart::{MBartConfig, MBartForSequenceClassification};
|
||||
/// # let config_path = Path::new("path/to/config.json");
|
||||
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||
/// # let device = Device::Cpu;
|
||||
@ -594,12 +594,11 @@ impl MBartForSequenceClassification {
|
||||
///
|
||||
/// let model_output = no_grad(|| {
|
||||
/// mbart_model
|
||||
/// .forward_t(Some(&input_tensor),
|
||||
/// .forward_t(&input_tensor,
|
||||
/// Some(&encoder_attention_mask),
|
||||
/// None,
|
||||
/// Some(&target_tensor),
|
||||
/// Some(&decoder_attention_mask),
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
/// ```
|
||||
|
@ -3,6 +3,7 @@ import numpy as np
|
||||
import torch
|
||||
import subprocess
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -17,14 +18,16 @@ if __name__ == "__main__":
|
||||
nps = {}
|
||||
for k, v in weights.items():
|
||||
k = k.replace("gamma", "weight").replace("beta", "bias")
|
||||
nps[k] = np.ascontiguousarray(v.cpu().numpy())
|
||||
|
||||
if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}:
|
||||
continue
|
||||
nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
|
||||
print(k + str(sys.getsizeof(nps[k])))
|
||||
np.savez(target_folder / 'model.npz', **nps)
|
||||
|
||||
source = str(target_folder / 'model.npz')
|
||||
target = str(target_folder / 'rust_model.ot')
|
||||
|
||||
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
|
||||
subprocess.run(
|
||||
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target],
|
||||
)
|
||||
# source = str(target_folder / 'model.npz')
|
||||
# target = str(target_folder / 'rust_model.ot')
|
||||
#
|
||||
# toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
|
||||
# subprocess.run(
|
||||
# ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target],
|
||||
# )
|
||||
|
Loading…
Reference in New Issue
Block a user