Updated clippy settings and bart download script

This commit is contained in:
guillaume-be 2020-09-14 06:48:20 +02:00
parent 44ea6dda25
commit 882ec1744b
3 changed files with 14 additions and 14 deletions

View File

@ -10,7 +10,7 @@ jobs:
- before_script: - before_script:
- rustup component add clippy - rustup component add clippy
script: script:
- cargo clippy --all-targets --all-features -- -D warnings - cargo clippy --all-targets --all-features -- -D warnings -A clippy::assign_op_pattern
- script: - script:
- cargo build --verbose - cargo build --verbose
- os: - os:

View File

@ -141,7 +141,7 @@ impl Attention {
) -> (Tensor, Option<Tensor>) { ) -> (Tensor, Option<Tensor>) {
let mut w = query.matmul(&key); let mut w = query.matmul(&key);
if self.scale { if self.scale {
w /= (*value.size().last().unwrap() as f64).sqrt(); w = w / (*value.size().last().unwrap() as f64).sqrt();
} }
let (nd, ns) = (w.size()[2], w.size()[3]); let (nd, ns) = (w.size()[2], w.size()[3]);
@ -149,7 +149,7 @@ impl Attention {
let mut w: Tensor = w * &b + 1e4 * (&b - 1); let mut w: Tensor = w * &b + 1e4 * (&b - 1);
if let Some(mask) = attention_mask { if let Some(mask) = attention_mask {
w += mask; w = w + mask;
} }
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train); w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
let output = w.matmul(&value); let output = w.matmul(&value);

View File

@ -39,14 +39,14 @@ for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias") k = k.replace("gamma", "weight").replace("beta", "bias")
nps[k] = np.ascontiguousarray(v.cpu().numpy()) nps[k] = np.ascontiguousarray(v.cpu().numpy())
# np.savez(target_path / 'model.npz', **nps) np.savez(target_path / 'model.npz', **nps)
#
# source = str(target_path / 'model.npz') source = str(target_path / 'model.npz')
# target = str(target_path / 'model.ot') target = str(target_path / 'model.ot')
#
# toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve() toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
#
# subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target]) subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
#
# os.remove(str(target_path / 'model.bin')) os.remove(str(target_path / 'model.bin'))
# os.remove(str(target_path / 'model.npz')) os.remove(str(target_path / 'model.npz'))