mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-05 16:47:24 +03:00
Updated clippy settings and bart download script
This commit is contained in:
parent
44ea6dda25
commit
882ec1744b
@ -10,7 +10,7 @@ jobs:
|
|||||||
- before_script:
|
- before_script:
|
||||||
- rustup component add clippy
|
- rustup component add clippy
|
||||||
script:
|
script:
|
||||||
- cargo clippy --all-targets --all-features -- -D warnings
|
- cargo clippy --all-targets --all-features -- -D warnings -A clippy::assign_op_pattern
|
||||||
- script:
|
- script:
|
||||||
- cargo build --verbose
|
- cargo build --verbose
|
||||||
- os:
|
- os:
|
||||||
|
@ -141,7 +141,7 @@ impl Attention {
|
|||||||
) -> (Tensor, Option<Tensor>) {
|
) -> (Tensor, Option<Tensor>) {
|
||||||
let mut w = query.matmul(&key);
|
let mut w = query.matmul(&key);
|
||||||
if self.scale {
|
if self.scale {
|
||||||
w /= (*value.size().last().unwrap() as f64).sqrt();
|
w = w / (*value.size().last().unwrap() as f64).sqrt();
|
||||||
}
|
}
|
||||||
|
|
||||||
let (nd, ns) = (w.size()[2], w.size()[3]);
|
let (nd, ns) = (w.size()[2], w.size()[3]);
|
||||||
@ -149,7 +149,7 @@ impl Attention {
|
|||||||
|
|
||||||
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
|
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
|
||||||
if let Some(mask) = attention_mask {
|
if let Some(mask) = attention_mask {
|
||||||
w += mask;
|
w = w + mask;
|
||||||
}
|
}
|
||||||
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
|
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
|
||||||
let output = w.matmul(&value);
|
let output = w.matmul(&value);
|
||||||
|
@ -39,14 +39,14 @@ for k, v in weights.items():
|
|||||||
k = k.replace("gamma", "weight").replace("beta", "bias")
|
k = k.replace("gamma", "weight").replace("beta", "bias")
|
||||||
nps[k] = np.ascontiguousarray(v.cpu().numpy())
|
nps[k] = np.ascontiguousarray(v.cpu().numpy())
|
||||||
|
|
||||||
# np.savez(target_path / 'model.npz', **nps)
|
np.savez(target_path / 'model.npz', **nps)
|
||||||
#
|
|
||||||
# source = str(target_path / 'model.npz')
|
source = str(target_path / 'model.npz')
|
||||||
# target = str(target_path / 'model.ot')
|
target = str(target_path / 'model.ot')
|
||||||
#
|
|
||||||
# toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
|
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
|
||||||
#
|
|
||||||
# subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
|
subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
|
||||||
#
|
|
||||||
# os.remove(str(target_path / 'model.bin'))
|
os.remove(str(target_path / 'model.bin'))
|
||||||
# os.remove(str(target_path / 'model.npz'))
|
os.remove(str(target_path / 'model.npz'))
|
||||||
|
Loading…
Reference in New Issue
Block a user