diff --git a/src/mbart/decoder.rs b/src/mbart/decoder.rs index 7d7dace..3013ea2 100644 --- a/src/mbart/decoder.rs +++ b/src/mbart/decoder.rs @@ -199,7 +199,7 @@ impl MBartDecoder { vec![config.d_model], Default::default(), ); - let layer_norm = nn::layer_norm(p / "layernorm", vec![config.d_model], Default::default()); + let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default()); let embed_positions = MBartLearnedPositionalEmbedding::new( p / "embed_positions", diff --git a/src/mbart/encoder.rs b/src/mbart/encoder.rs index f270323..120cab3 100644 --- a/src/mbart/encoder.rs +++ b/src/mbart/encoder.rs @@ -157,7 +157,7 @@ impl MBartEncoder { Default::default(), ); - let layer_norm = nn::layer_norm(p / "layernorm", vec![config.d_model], Default::default()); + let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default()); let embed_positions = MBartLearnedPositionalEmbedding::new( p / "embed_positions", diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs index fca0177..fbc3e03 100644 --- a/src/mbart/mbart_model.rs +++ b/src/mbart/mbart_model.rs @@ -579,7 +579,7 @@ impl MBartForSequenceClassification { /// # use rust_bert::Config; /// # use std::path::Path; /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::mbart::{MBartConfig, MBartForConditionalGeneration}; + /// use rust_bert::mbart::{MBartConfig, MBartForSequenceClassification}; /// # let config_path = Path::new("path/to/config.json"); /// # let vocab_path = Path::new("path/to/vocab.txt"); /// # let device = Device::Cpu; @@ -594,12 +594,11 @@ impl MBartForSequenceClassification { /// /// let model_output = no_grad(|| { /// mbart_model - /// .forward_t(Some(&input_tensor), + /// .forward_t(&input_tensor, /// Some(&encoder_attention_mask), /// None, /// Some(&target_tensor), /// Some(&decoder_attention_mask), - /// None, /// false) /// }); /// ``` diff --git a/utils/convert_model.py b/utils/convert_model.py index 0ee3c91..0154e9d 100644 --- a/utils/convert_model.py +++ b/utils/convert_model.py @@ -3,6 +3,7 @@ import numpy as np import torch import subprocess import argparse +import sys if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -17,14 +18,16 @@ if __name__ == "__main__": nps = {} for k, v in weights.items(): k = k.replace("gamma", "weight").replace("beta", "bias") - nps[k] = np.ascontiguousarray(v.cpu().numpy()) - + if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}: + continue + nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32)) + print(k + str(sys.getsizeof(nps[k]))) np.savez(target_folder / 'model.npz', **nps) - source = str(target_folder / 'model.npz') - target = str(target_folder / 'rust_model.ot') - - toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve() - subprocess.run( - ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target], - ) + # source = str(target_folder / 'model.npz') + # target = str(target_folder / 'rust_model.ot') + # + # toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve() + # subprocess.run( + # ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target], + # )