flax2pt script

This commit is contained in:
arampacha 2021-07-16 10:34:11 +00:00
parent f49fda52e2
commit 2a8b8b785d

14
flax2pt.py Executable file
View File

@ -0,0 +1,14 @@
#!/usr/bin/env python
import argparse
from transformers import AutoModel
parser = argparse.ArgumentParser()
parser.add_argument("model_dir", help="Path to directory containing config.json and flax_model.msgpack")
args = parser.parse_args()
print(f"Loading flax model from {args.model_dir}...")
model = AutoModel.from_pretrained(args.model_dir, from_flax=True)
print(f"Saving pytorch model...")
model.save_pretrained(args.model_dir, save_config=False)