mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-07-07 08:56:34 +03:00
Updated clippy settings and bart download script
This commit is contained in:
parent
44ea6dda25
commit
882ec1744b
|
@ -10,7 +10,7 @@ jobs:
|
|||
- before_script:
|
||||
- rustup component add clippy
|
||||
script:
|
||||
- cargo clippy --all-targets --all-features -- -D warnings
|
||||
- cargo clippy --all-targets --all-features -- -D warnings -A clippy::assign_op_pattern
|
||||
- script:
|
||||
- cargo build --verbose
|
||||
- os:
|
||||
|
|
|
@ -141,7 +141,7 @@ impl Attention {
|
|||
) -> (Tensor, Option<Tensor>) {
|
||||
let mut w = query.matmul(&key);
|
||||
if self.scale {
|
||||
w /= (*value.size().last().unwrap() as f64).sqrt();
|
||||
w = w / (*value.size().last().unwrap() as f64).sqrt();
|
||||
}
|
||||
|
||||
let (nd, ns) = (w.size()[2], w.size()[3]);
|
||||
|
@ -149,7 +149,7 @@ impl Attention {
|
|||
|
||||
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
|
||||
if let Some(mask) = attention_mask {
|
||||
w += mask;
|
||||
w = w + mask;
|
||||
}
|
||||
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
|
||||
let output = w.matmul(&value);
|
||||
|
|
|
@ -39,14 +39,14 @@ for k, v in weights.items():
|
|||
k = k.replace("gamma", "weight").replace("beta", "bias")
|
||||
nps[k] = np.ascontiguousarray(v.cpu().numpy())
|
||||
|
||||
# np.savez(target_path / 'model.npz', **nps)
|
||||
#
|
||||
# source = str(target_path / 'model.npz')
|
||||
# target = str(target_path / 'model.ot')
|
||||
#
|
||||
# toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
|
||||
#
|
||||
# subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
|
||||
#
|
||||
# os.remove(str(target_path / 'model.bin'))
|
||||
# os.remove(str(target_path / 'model.npz'))
|
||||
np.savez(target_path / 'model.npz', **nps)
|
||||
|
||||
source = str(target_path / 'model.npz')
|
||||
target = str(target_path / 'model.ot')
|
||||
|
||||
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
|
||||
|
||||
subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
|
||||
|
||||
os.remove(str(target_path / 'model.bin'))
|
||||
os.remove(str(target_path / 'model.npz'))
|
||||
|
|
Loading…
Reference in New Issue
Block a user