Updated model weight download

This commit is contained in:
Guillaume B 2020-04-24 17:04:28 +02:00
parent 1451a64b89
commit c65943ab2f
16 changed files with 43 additions and 10 deletions

View File

@ -32,10 +32,10 @@ features = [ "doc-only" ]
[dependencies]
rust_tokenizers = "2.0.4"
tch = "0.1.6"
serde_json = "1.0.45"
serde = {version = "1.0.104", features = ["derive"]}
failure = "0.1.6"
dirs = "2.0"
serde_json = "1.0.51"
serde = {version = "1.0.106", features = ["derive"]}
failure = "0.1.7"
dirs = "2.0.0"
itertools = "0.9.0"
ordered-float = "1.0.2"
csv = "1.1.3"

View File

@ -23,7 +23,7 @@ fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("distilbert_sst2");
home.push("distilbert-sst2");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");

View File

@ -24,7 +24,7 @@ fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("distilbert_sst2");
home.push("distilbert-sst2");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");

View File

@ -1,2 +1,2 @@
torch == 1.4.0
transformers == 2.6.0
torch == 1.5.0
transformers == 2.8.0

View File

@ -17,7 +17,7 @@ fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("distilbert_sst2");
home.push("distilbert-sst2");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");

View File

@ -48,3 +48,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

View File

@ -48,3 +48,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

View File

@ -43,3 +43,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

View File

@ -44,3 +44,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

View File

@ -42,3 +42,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

View File

@ -42,3 +42,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

View File

@ -46,3 +46,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

View File

@ -48,3 +48,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

View File

@ -48,3 +48,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

View File

@ -47,3 +47,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

View File

@ -12,7 +12,7 @@ config_path = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP["distilbert-base-uncased-
vocab_path = PRETRAINED_VOCAB_FILES_MAP["vocab_file"]["distilbert-base-uncased"]
weights_path = DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP["distilbert-base-uncased-finetuned-sst-2-english"]
target_path = Path.home() / 'rustbert' / 'distilbert_sst2'
target_path = Path.home() / 'rustbert' / 'distilbert-sst2'
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
@ -42,3 +42,6 @@ toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve(
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))