diff --git a/.travis.yml b/.travis.yml index ecd34bc..9d09914 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,7 +10,7 @@ jobs: - before_script: - rustup component add clippy script: - - cargo clippy --all-targets --all-features -- -D warnings + - cargo clippy --all-targets --all-features -- -D warnings -A clippy::assign_op_pattern - script: - cargo build --verbose - os: diff --git a/src/gpt2/attention.rs b/src/gpt2/attention.rs index 2807de1..4760cfc 100644 --- a/src/gpt2/attention.rs +++ b/src/gpt2/attention.rs @@ -141,7 +141,7 @@ impl Attention { ) -> (Tensor, Option) { let mut w = query.matmul(&key); if self.scale { - w /= (*value.size().last().unwrap() as f64).sqrt(); + w = w / (*value.size().last().unwrap() as f64).sqrt(); } let (nd, ns) = (w.size()[2], w.size()[3]); @@ -149,7 +149,7 @@ impl Attention { let mut w: Tensor = w * &b + 1e4 * (&b - 1); if let Some(mask) = attention_mask { - w += mask; + w = w + mask; } w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train); let output = w.matmul(&value); diff --git a/utils/download-dependencies_bart_large_mnli.py b/utils/download-dependencies_bart_large_mnli.py index cb814e1..e0de392 100644 --- a/utils/download-dependencies_bart_large_mnli.py +++ b/utils/download-dependencies_bart_large_mnli.py @@ -39,14 +39,14 @@ for k, v in weights.items(): k = k.replace("gamma", "weight").replace("beta", "bias") nps[k] = np.ascontiguousarray(v.cpu().numpy()) -# np.savez(target_path / 'model.npz', **nps) -# -# source = str(target_path / 'model.npz') -# target = str(target_path / 'model.ot') -# -# toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve() -# -# subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target]) -# -# os.remove(str(target_path / 'model.bin')) -# os.remove(str(target_path / 'model.npz')) +np.savez(target_path / 'model.npz', **nps) + +source = str(target_path / 'model.npz') +target = str(target_path / 'model.ot') + +toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve() + +subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target]) + +os.remove(str(target_path / 'model.bin')) +os.remove(str(target_path / 'model.npz'))