2020-02-11 21:16:35 +03:00
|
|
|
from transformers import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
|
|
from transformers.tokenization_distilbert import PRETRAINED_VOCAB_FILES_MAP
|
|
|
|
from transformers.file_utils import get_from_cache
|
|
|
|
from pathlib import Path
|
|
|
|
import shutil
|
|
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import subprocess
|
|
|
|
|
# Every artifact is taken from this single pretrained checkpoint.
_model_name = "distilbert-base-uncased"

# Archive locations of the checkpoint's configuration, vocabulary and weights,
# looked up in the pretrained maps exported by the transformers library.
config_path = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP[_model_name]
vocab_path = PRETRAINED_VOCAB_FILES_MAP["vocab_file"][_model_name]
weights_path = DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP[_model_name]

# Local directory (under the user's home) that receives the converted files.
target_path = Path.home() / 'rustbert' / 'distilbert'
# Fetch the configuration, vocabulary and weights through the transformers
# download cache; get_from_cache returns the path of the locally cached file.
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)

os.makedirs(str(target_path), exist_ok=True)

# Copy the configuration and vocabulary into the target directory unchanged.
# The weights are not copied: they are read straight from the cache below and
# only the converted model.ot is kept in target_path.
# Distinct *_dest names avoid shadowing the remote locations in
# config_path / vocab_path.
config_dest = str(target_path / 'config.json')
vocab_dest = str(target_path / 'vocab.txt')
shutil.copy(temp_config, config_dest)
shutil.copy(temp_vocab, vocab_dest)

# Load the PyTorch state dict on the CPU and turn every tensor into a
# C-contiguous NumPy array so it can be written to an .npz archive.
# NOTE(review): torch.load unpickles the downloaded file; only run this
# against trusted model sources.
weights = torch.load(temp_weights, map_location='cpu')
nps = {name: np.ascontiguousarray(tensor.cpu().numpy())
       for name, tensor in weights.items()}
np.savez(str(target_path / 'model.npz'), **nps)

# Convert the .npz archive to model.ot with the crate's `convert-tensor`
# binary. Cargo.toml is taken from the parent of this script's directory.
source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')
toml_location = Path(__file__).resolve().parent.parent / 'Cargo.toml'

# check=True raises CalledProcessError when cargo fails, so the script stops
# before deleting the intermediate archive and no missing model.ot goes unnoticed.
subprocess.run(
    ['cargo', 'run', '--bin=convert-tensor',
     '--manifest-path=%s' % toml_location, '--', source, target],
    check=True)

# The .npz archive is only an intermediate artifact; remove it once the
# conversion has succeeded.
os.remove(source)
|